sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/connector/remote_instance.py ADDED
@@ -0,0 +1,82 @@
+ # SPDX-License-Identifier: Apache-2.0
+
+ import logging
+ from typing import Generator, List, Optional, Tuple
+ from urllib.parse import urlparse
+
+ import torch
+ import torch.distributed as dist
+
+ from sglang.srt.connector import BaseConnector
+ from sglang.srt.utils import init_custom_process_group
+
+ logger = logging.getLogger(__name__)
+
+
+ class RemoteInstanceConnector(BaseConnector):
+
+     def __init__(self, url: str, device: torch.device = "cpu"):
+         assert (
+             device.type == "cuda"
+         ), "RemoteInstanceConnector only supports cuda device."
+         super().__init__(url)
+         self.url = url
+         self.device = device
+
+     def build_group(
+         self,
+         gpu_id: int = -1,
+         tp_rank: int = -1,
+         instance_ip: str = None,
+         group_rank: int = 1,
+         world_size: int = 2,
+     ):
+         assert (
+             self.device.type == "cuda"
+         ), "RemoteInstanceConnector only supports cuda device."
+         assert (
+             gpu_id != -1 and tp_rank != -1
+         ), "gpu_id and tp_rank must be specified for RemoteInstanceConnector. "
+
+         self.device_id = torch.device(self.device.type, gpu_id)
+
+         parsed_url = urlparse(self.url)
+         master_address = parsed_url.hostname
+         master_port = parsed_url.port
+         group_name = f"send_weights_{instance_ip}_{master_port}_{tp_rank}"
+         backend = "nccl"
+
+         logger.info(
+             f"init custom process group: master_address={master_address}, master_port={master_port}, "
+             f"rank_offset={group_rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
+         )
+
+         try:
+             self._model_update_group = init_custom_process_group(
+                 backend=backend,
+                 init_method=f"tcp://{master_address}:{master_port}",
+                 world_size=world_size,
+                 rank=group_rank,
+                 group_name=group_name,
+                 device_id=self.device_id,
+             )
+             dist.barrier(group=self._model_update_group)
+             return True, "Succeeded to initialize custom process group."
+         except Exception as e:
+             message = f"Failed to initialize custom process group: {e}."
+             logger.error(message)
+             return False, message
+
+     # Implemented as a no-op to make BaseConnector interface consistent.
+     def pull_files(
+         self,
+         allow_pattern: Optional[list[str]] = None,
+         ignore_pattern: Optional[list[str]] = None,
+     ) -> None:
+         return
+
+     # Implemented as a no-op to make BaseConnector interface consistent.
+     def weight_iterator(
+         self, rank: int = 0
+     ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+         return
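The new connector's contract is easy to miss when skimming the hunk above: the device must be CUDA, `gpu_id`/`tp_rank` are mandatory despite their `-1` defaults, and `build_group` blocks until the remote side joins the rendezvous. A minimal usage sketch follows; the address, IP, and ranks are invented for illustration and are not taken from the diff.

```python
# Hypothetical caller of the new RemoteInstanceConnector; URL, IP, and ranks
# are placeholders.
import torch

from sglang.srt.connector.remote_instance import RemoteInstanceConnector

connector = RemoteInstanceConnector(
    url="tcp://10.0.0.1:29500",   # hostname/port are parsed with urlparse()
    device=torch.device("cuda"),  # the constructor asserts device.type == "cuda"
)

# Joins a 2-process NCCL group named f"send_weights_{instance_ip}_{port}_{tp_rank}";
# this blocks until the peer at the URL initializes the same group.
ok, message = connector.build_group(
    gpu_id=0,    # required; the default -1 trips the assertion
    tp_rank=0,   # required; the default -1 trips the assertion
    instance_ip="10.0.0.2",
    group_rank=1,
    world_size=2,
)
print(ok, message)
```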
sglang/srt/constrained/base_grammar_backend.py CHANGED
@@ -14,8 +14,9 @@
  """The baseclass of a backend for grammar-guided constrained decoding."""

  import logging
+ import time
  from concurrent.futures import ThreadPoolExecutor
- from dataclasses import dataclass
+ from dataclasses import dataclass, field
  from threading import Event
  from typing import Dict, List, Optional, Tuple

@@ -26,10 +27,23 @@ from sglang.srt.server_args import ServerArgs
  logger = logging.getLogger(__name__)


+ @dataclass
+ class GrammarStats:
+     compilation_time: Optional[float] = None
+     schema_count: Optional[int] = None
+     ebnf_size: Optional[int] = None
+     is_cache_hit: bool = False
+     is_grammar_aborted: bool = False
+     tree_traversal_time: List[float] = field(default_factory=list)
+     dispatch_type: Optional[str] = None
+
+
  class BaseGrammarObject:

      def __init__(self):
          self._finished = False
+         self.grammar_stats = None
+         self.current_token = None

      def accept_token(self, token: int) -> None:
          """
@@ -137,19 +151,26 @@ class BaseGrammarBackend:
          return self._not_supported("structural_tag", key_string)

      def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
+         s = time.perf_counter()
          key_type, key_string = key
          if key_type == "json":
-             return self.dispatch_json(key_string)
+             grammar = self.dispatch_json(key_string)
          elif key_type == "regex":
-             return self.dispatch_regex(key_string)
+             grammar = self.dispatch_regex(key_string)
          elif key_type == "ebnf":
-             return self.dispatch_ebnf(key_string)
+             grammar = self.dispatch_ebnf(key_string)
          elif key_type == "structural_tag":
-             return self.dispatch_structural_tag(key_string)
+             grammar = self.dispatch_structural_tag(key_string)
          elif key_type == "structural_pattern":
-             return self.dispatch_structural_pattern(key_string)
+             grammar = self.dispatch_structural_pattern(key_string)
+         elif key_type == "structural_pattern_v2":
+             grammar = self.dispatch_structural_pattern_v2(key_string)
          else:
-             return self.dispatch_fallback(key_type, key_string)
+             grammar = self.dispatch_fallback(key_type, key_string)
+
+         if grammar is not None and grammar.grammar_stats is not None:
+             grammar.grammar_stats.compilation_time = time.perf_counter() - s
+         return grammar

      def get_cached_or_future_value(
          self, key: Tuple[str, str]
@@ -167,20 +188,36 @@ class BaseGrammarBackend:
          self.cache.clear()


+ GRAMMAR_BACKEND_REGISTRY = {}
+
+
+ def register_grammar_backend(name, init_func):
+     GRAMMAR_BACKEND_REGISTRY[name] = init_func
+
+
  def create_grammar_backend(
      server_args: ServerArgs,
      tokenizer,
      vocab_size: int,
      eos_token_ids: Optional[set] = None,
  ) -> Optional[BaseGrammarBackend]:
-     if server_args.grammar_backend == "outlines":
+     name = server_args.grammar_backend
+
+     # Custom grammar backend has the highest priority
+     if name in GRAMMAR_BACKEND_REGISTRY:
+         return GRAMMAR_BACKEND_REGISTRY[name](
+             server_args, tokenizer, vocab_size, eos_token_ids
+         )
+
+     # Default grammar backends
+     if name == "outlines":
          from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend

          grammar_backend = OutlinesGrammarBackend(
              tokenizer,
              whitespace_pattern=server_args.constrained_json_whitespace_pattern,
          )
-     elif server_args.grammar_backend == "xgrammar":
+     elif name == "xgrammar":
          from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend

          # Convert Set[int] to List[int] if needed
@@ -189,17 +226,17 @@ def create_grammar_backend(
          grammar_backend = XGrammarGrammarBackend(
              tokenizer, vocab_size=vocab_size, model_eos_token_ids=eos_list
          )
-     elif server_args.grammar_backend == "llguidance":
+     elif name == "llguidance":
          from sglang.srt.constrained.llguidance_backend import GuidanceBackend

          grammar_backend = GuidanceBackend(
              tokenizer=tokenizer,
              whitespace_pattern=server_args.constrained_json_whitespace_pattern,
          )
-     elif server_args.grammar_backend == "none":
+     elif name == "none":
          return None
      else:
-         raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
+         raise ValueError(f"Invalid grammar backend: {name}")

      if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
          from sglang.srt.constrained.reasoner_grammar_backend import (
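Among the changes above, the registry is the new plug-in point: `register_grammar_backend` lets an out-of-tree backend be selected by name, and registered names are checked before the built-in `outlines`/`xgrammar`/`llguidance` branches. A sketch of how it could be used; the `MyGrammarBackend` class and the `"mybackend"` name are hypothetical, and only `register_grammar_backend` and the `init_func` call signature come from the diff.

```python
# Hypothetical custom backend registration; MyGrammarBackend is illustrative only.
from sglang.srt.constrained.base_grammar_backend import (
    BaseGrammarBackend,
    register_grammar_backend,
)


class MyGrammarBackend(BaseGrammarBackend):
    """Toy backend that declines every JSON-schema request."""

    def dispatch_json(self, key_string):
        return None


def _init_my_backend(server_args, tokenizer, vocab_size, eos_token_ids):
    # Same positional arguments that create_grammar_backend passes to a
    # registered init_func.
    return MyGrammarBackend()


register_grammar_backend("mybackend", _init_my_backend)
```

Once registered, `create_grammar_backend` returns this backend whenever `server_args.grammar_backend == "mybackend"`, bypassing the default branches.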
sglang/srt/constrained/llguidance_backend.py CHANGED
@@ -48,7 +48,6 @@ class GuidanceGrammar(BaseGrammarObject):
              self.serialized_grammar,
              log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
          )
-         self.finished = False
          self.bitmask = None

      def accept_token(self, token: int):
sglang/srt/constrained/outlines_backend.py CHANGED
@@ -49,7 +49,6 @@ class OutlinesGrammar(BaseGrammarObject):
          self.guide = guide
          self.jump_forward_map = jump_forward_map
          self.state = 0
-         self.finished = False

      def accept_token(self, token: int):
          self.state = self.guide.get_next_state(self.state, token)
sglang/srt/constrained/outlines_jump_forward.py CHANGED
@@ -37,7 +37,7 @@ except ImportError:

  IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

- # Env var was set in sglang.srt.server_args.ServerArgs.__post__init__
+ # Env var was set in sglang.srt.server_args.ServerArgs.__post_init__
  DISABLE_DISK_CACHE = get_bool_env_var("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true")

  logger = logging.getLogger(__name__)
sglang/srt/constrained/xgrammar_backend.py CHANGED
@@ -13,6 +13,7 @@
  # ==============================================================================
  """Constrained decoding with xgrammar backend."""

+ import dataclasses
  import json
  import logging
  from typing import List, Optional, Tuple, Union
@@ -31,6 +32,7 @@ from sglang.srt.constrained.base_grammar_backend import (
      INVALID_GRAMMAR_OBJ,
      BaseGrammarBackend,
      BaseGrammarObject,
+     GrammarStats,
  )
  from sglang.srt.utils import is_hip

@@ -41,9 +43,9 @@ else:
      from sglang.srt.constrained.triton_ops.bitmask_ops import (
          apply_token_bitmask_inplace_triton,
      )
- logger = logging.getLogger(__name__)


+ logger = logging.getLogger(__name__)
  MAX_ROLLBACK_TOKENS = 200


@@ -56,17 +58,20 @@ class XGrammarGrammar(BaseGrammarObject):
          ctx: CompiledGrammar,
          override_stop_tokens: Optional[Union[List[int], int]],
          key_string: Optional[str] = None,  # TODO (sk): for debugging, remove later
+         grammar_stats: Optional[GrammarStats] = GrammarStats(),
      ) -> None:
+         super().__init__()
          self.matcher = matcher
          self.vocab_size = vocab_size
          self.ctx = ctx
          self.override_stop_tokens = override_stop_tokens
-         self.finished = False
          self.accepted_tokens = []
          self.key_string = key_string
+         self.grammar_stats = grammar_stats

      def accept_token(self, token: int):
          if not self.is_terminated():
+             self.current_token = token
              accepted = self.matcher.accept_token(token)
              if not accepted:
                  # log for debugging
@@ -120,6 +125,9 @@ class XGrammarGrammar(BaseGrammarObject):
              self.ctx,
              self.override_stop_tokens,
              self.key_string,
+             dataclasses.replace(
+                 self.grammar_stats, is_cache_hit=True, tree_traversal_time=[]
+             ),
          )

      def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
@@ -150,7 +158,7 @@
              assert self.matcher.accept_token(new_output_ids[i])

      def __repr__(self):
-         return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=})"
+         return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=}, {self.current_token=})"


  class XGrammarGrammarBackend(BaseGrammarBackend):
@@ -165,6 +173,10 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
          if hasattr(tokenizer, "init_xgrammar"):
              # For special tokenizer
              tokenizer_info, override_stop_tokens = tokenizer.init_xgrammar()
+
+             if tokenizer_info is None:
+                 # Not supported tokenizer
+                 return
          else:
              # Create TokenizerInfo with model's EOS tokens as the authoritative stop tokens
              # This ensures consistency between what the model considers EOS and what XGrammar uses
@@ -177,14 +189,21 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
          self.vocab_size = vocab_size
          self.override_stop_tokens = override_stop_tokens

-     def _from_context(self, ctx: CompiledGrammar, key_string: str) -> XGrammarGrammar:
+     def _from_context(
+         self, ctx: CompiledGrammar, key_string: str, grammar_stats: GrammarStats
+     ) -> XGrammarGrammar:
          matcher = GrammarMatcher(
              ctx,
              max_rollback_tokens=MAX_ROLLBACK_TOKENS,
              override_stop_tokens=self.override_stop_tokens,
          )
          return XGrammarGrammar(
-             matcher, self.vocab_size, ctx, self.override_stop_tokens, key_string
+             matcher,
+             self.vocab_size,
+             ctx,
+             self.override_stop_tokens,
+             key_string,
+             grammar_stats,
          )

      def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
@@ -198,7 +217,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
          except (RuntimeError, json.decoder.JSONDecodeError) as e:
              logging.error(f"Hit invalid json_schema: {key_string=}, {e=}")
              return INVALID_GRAMMAR_OBJ
-         return self._from_context(ctx, key_string)
+         return self._from_context(ctx, key_string, GrammarStats(dispatch_type="json"))

      def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
          try:
@@ -206,7 +225,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
          except RuntimeError as e:
              logging.error(f"Hit invalid ebnf: {key_string=}, {e=}")
              return INVALID_GRAMMAR_OBJ
-         return self._from_context(ctx, key_string)
+         return self._from_context(ctx, key_string, GrammarStats(dispatch_type="ebnf"))

      def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
          try:
@@ -214,7 +233,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
          except RuntimeError as e:
              logging.error(f"Hit invalid regex: {key_string=}, {e=}")
              return INVALID_GRAMMAR_OBJ
-         return self._from_context(ctx, key_string)
+         return self._from_context(ctx, key_string, GrammarStats(dispatch_type="regex"))

      def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
          try:
@@ -233,7 +252,9 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
          except (RuntimeError, json.decoder.JSONDecodeError) as e:
              logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
              return INVALID_GRAMMAR_OBJ
-         return self._from_context(ctx, key_string)
+         return self._from_context(
+             ctx, key_string, GrammarStats(dispatch_type="structural_tag")
+         )

      def reset(self):
          self.grammar_compiler.clear_cache()
sglang/srt/custom_op.py CHANGED
@@ -1,12 +1,20 @@
  from torch import nn

- from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
+ from sglang.srt.utils import (
+     cpu_has_amx_support,
+     is_cpu,
+     is_cuda,
+     is_hip,
+     is_npu,
+     is_xpu,
+ )

  _is_cuda = is_cuda()
  _is_hip = is_hip()
  _is_cpu = is_cpu()
  _is_cpu_amx_available = cpu_has_amx_support()
  _is_npu = is_npu()
+ _is_xpu = is_xpu()


  class CustomOp(nn.Module):
@@ -88,5 +96,7 @@ class CustomOp(nn.Module):
              return self.forward_cpu
          elif _is_npu:
              return self.forward_npu
+         elif _is_xpu:
+             return self.forward_xpu
          else:
              return self.forward_native
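For op authors, the practical effect of the two hunks above is that a `forward_xpu` method is now reachable through `CustomOp`'s backend dispatch on Intel XPU builds instead of falling through to `forward_native`. A toy subclass is shown below; `SiluScale` is invented for illustration and is not part of the diff.

```python
# Illustrative CustomOp subclass; SiluScale is not part of sglang.
import torch

from sglang.srt.custom_op import CustomOp


class SiluScale(CustomOp):
    """Toy op: silu(x) * scale."""

    def __init__(self, scale: float):
        super().__init__()
        self.scale = scale

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.silu(x) * self.scale

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_native(x)

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # Same math here; a real op could call an XPU-specific kernel instead.
        # On an XPU build, the dispatch branch added above selects this method.
        return self.forward_native(x)
```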
sglang/srt/debug_utils/dump_comparator.py CHANGED
@@ -1,11 +1,11 @@
  import argparse
  import functools
- import re
  from pathlib import Path

  import polars as pl
  import torch

+ from sglang.srt.debug_utils.dump_loader import find_row, read_meta
  from sglang.srt.debug_utils.dumper import get_truncated_value


@@ -26,66 +26,77 @@ def main(args):
      print("df_baseline", df_baseline)

      for row in df_target.iter_rows(named=True):
-         rows_baseline = df_baseline.filter(
-             (
-                 pl.col("forward_pass_id")
-                 == row["forward_pass_id"] - args.start_id + args.baseline_start_id
-             )
-             & functools.reduce(
-                 lambda a, b: a & b,
-                 [
-                     pl.col(col) == row[col]
-                     for col in row.keys()
-                     if col not in ["forward_pass_id", "dump_index", "filename"]
-                 ],
-             )
+         path_target = Path(args.target_path) / row["filename"]
+
+         row_baseline = find_row(
+             df_baseline,
+             conditions=dict(
+                 forward_pass_id=row["forward_pass_id"]
+                 - args.start_id
+                 + args.baseline_start_id,
+                 **{
+                     k: v
+                     for k, v in row.items()
+                     if k not in ["forward_pass_id", "dump_index", "filename"]
+                 },
+             ),
          )
-         assert len(rows_baseline) == 1, f"{rows_baseline=}"
-         row_baseline = rows_baseline.to_dicts()[0]
+
+         if row_baseline is None:
+             print(f"Skip: target={str(path_target)} since no baseline")
+             x_target = _load_object(path_target)
+             if x_target is not None:
+                 print(f"x_target(sample)={get_truncated_value(x_target)}")
+             continue

          path_baseline = Path(args.baseline_path) / row_baseline["filename"]
-         path_target = Path(args.target_path) / row["filename"]
          print(f"Check: target={str(path_target)} baseline={str(path_baseline)}")
-         check_tensor_pair(path_baseline=path_baseline, path_target=path_target)
+         check_tensor_pair(
+             path_baseline=path_baseline, path_target=path_target, name=row["name"]
+         )
          print()


- def read_meta(directory):
-     directory = Path(directory)
-     assert directory.is_dir(), f"{directory=} should be a directory"
-
-     rows = []
-     for p in directory.glob("*.pt"):
-         full_kwargs = {}
-         for kv in p.stem.split("___"):
-             k, v = kv.split("=")
-             full_kwargs[k] = v
-         rows.append(
-             {
-                 "filename": str(p.name),
-                 **full_kwargs,
-             }
-         )
+ def check_tensor_pair(path_baseline, path_target, name=""):
+     x_baseline = _load_object(path_baseline)
+     x_target = _load_object(path_target)

-     df = pl.DataFrame(rows)
-     df = df.with_columns(
-         pl.col("forward_pass_id").cast(int),
-         pl.col("rank").cast(int),
+     print(
+         f"Raw "
+         f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
+         f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
      )
-     return df
-

- def check_tensor_pair(path_baseline, path_target):
-     x_baseline = torch.load(path_baseline, weights_only=True)
-     x_target = torch.load(path_target, weights_only=True)
+     x_baseline, x_target = _comparison_preprocessor(x_baseline, x_target, name=name)
+     x_baseline = _try_unify_shape(x_baseline, target_shape=x_target.shape)

      print(
+         f"After preprocessor "
          f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
          f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
      )

+     x_target = x_target.float()
+     x_baseline = x_baseline.float()
+
+     for name, fn in (
+         ("mean", torch.mean),
+         ("std", torch.std),
+         ("min", torch.min),
+         ("max", torch.max),
+         ("p1", functools.partial(torch.quantile, q=0.01)),
+         ("p5", functools.partial(torch.quantile, q=0.05)),
+         ("p95", functools.partial(torch.quantile, q=0.95)),
+         ("p99", functools.partial(torch.quantile, q=0.99)),
+     ):
+         value_baseline = fn(x_baseline).item()
+         value_target = fn(x_target).item()
+         print(
+             f"[{name}] {value_baseline :.4f} vs {value_target:.4f} (diff: {value_target - value_baseline:.4f})"
+         )
+
      if x_baseline.shape != x_target.shape:
-         print(f" Shape mismatch")
+         print(f"⚠️ Shape mismatch")
          return

      raw_abs_diff = (x_target - x_baseline).abs()
@@ -112,6 +123,19 @@ def check_tensor_pair(path_baseline, path_target):
      print(f"x_target(sample)={get_truncated_value(x_target)}")


+ def _try_unify_shape(x: torch.Tensor, target_shape):
+     x_shape = x.shape
+     num_dim_to_remove = len(x_shape) - len(target_shape)
+     if (x_shape[num_dim_to_remove:] == target_shape) and all(
+         val == 1 for val in x_shape[:num_dim_to_remove]
+     ):
+         out = functools.reduce(lambda a, _: a.squeeze(0), range(num_dim_to_remove), x)
+         print(f"Unify shape: {x_shape} -> {out.shape} (to match {target_shape})")
+         return out
+
+     return x
+
+
  # Copied from DeepGEMM
  def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
      x, y = x.double(), y.double()
@@ -120,6 +144,19 @@ def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
      return 1 - sim


+ def _comparison_preprocessor(x_baseline, x_target, name):
+     # can insert arbitrary adhoc postprocessing logic here
+     return x_baseline, x_target
+
+
+ def _load_object(path):
+     x = torch.load(path, weights_only=False)
+     if not isinstance(x, torch.Tensor):
+         print(f"Skip load {path} since {type(x)=} is not a Tensor")
+         return None
+     return x.cuda()
+
+
  if __name__ == "__main__":
      parser = argparse.ArgumentParser()
      parser.add_argument("--baseline-path", type=str)
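One behavior worth calling out from the hunks above: `_try_unify_shape` lets a baseline dump with extra leading singleton dimensions still be compared against the target. A quick sketch of what it does; the shapes are made up, while the helper itself comes from the diff.

```python
# Leading size-1 dims on the baseline are squeezed away to match the target shape.
import torch

from sglang.srt.debug_utils.dump_comparator import _try_unify_shape

x_baseline = torch.randn(1, 1, 4, 8)
target_shape = torch.Size([4, 8])

unified = _try_unify_shape(x_baseline, target_shape=target_shape)
assert unified.shape == target_shape  # prints "Unify shape: ... -> ..." along the way
```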
sglang/srt/debug_utils/dump_loader.py ADDED
@@ -0,0 +1,97 @@
+ import functools
+ import os
+ from pathlib import Path
+ from typing import Any, Dict
+
+ import polars as pl
+ import torch
+
+
+ class DumpLoader:
+     def __init__(self):
+         directory = os.environ.get("SGLANG_DUMP_LOADER_DIR")
+
+         self._enable = directory is not None
+         if self._enable:
+             self._directory = Path(directory)
+             self._df = read_meta(directory)
+
+     @property
+     def enable(self):
+         return self._enable
+
+     def load(self, name, **kwargs):
+         assert self._enable, "Please call DumpLoader.load only when it is enabled"
+
+         from sglang.srt.debug_utils.dumper import dumper
+
+         forward_pass_id = dumper._forward_pass_id
+         conditions = dict(name=name, forward_pass_id=forward_pass_id, **kwargs)
+         row = find_row(self._df, conditions=conditions)
+         assert (
+             row is not None
+         ), f"DumpLoader cannot find row given query {name=} {kwargs=} {self._directory=}"
+
+         path = self._directory / row["filename"]
+         output = torch.load(path, weights_only=False)
+
+         print(
+             f"[DumpLoader] load from {path=} (query: {name=} {kwargs=}, output: {type(output)})"
+         )
+         return output
+
+
+ def read_meta(directory):
+     directory = Path(directory)
+     assert directory.is_dir(), f"{directory=} should be a directory"
+
+     rows = []
+     for p in directory.glob("*.pt"):
+         full_kwargs = {}
+         for kv in p.stem.split("___"):
+             k, v = kv.split("=")
+             full_kwargs[k] = v
+         rows.append(
+             {
+                 "filename": str(p.name),
+                 **full_kwargs,
+             }
+         )
+
+     df = pl.DataFrame(rows)
+     df = df.with_columns(
+         pl.col("forward_pass_id").cast(int),
+         pl.col("rank").cast(int),
+         pl.col("dump_index").cast(int),
+     )
+     return df
+
+
+ def find_row(df, conditions: Dict[str, Any]):
+     df_sub = df.filter(
+         functools.reduce(
+             lambda a, b: a & b,
+             [
+                 pl.col(col) == _cast_to_polars_dtype(conditions[col], df.schema[col])
+                 for col in conditions.keys()
+             ],
+         )
+     )
+     assert len(df_sub) <= 1
+     return df_sub.to_dicts()[0] if len(df_sub) > 0 else None
+
+
+ def _cast_to_polars_dtype(value, target_dtype):
+     if target_dtype in (pl.Int64, pl.Int32, pl.UInt64, pl.UInt32):
+         return int(value)
+     elif target_dtype in (pl.Float64, pl.Float32):
+         return float(value)
+     elif target_dtype == pl.Boolean:
+         return bool(value)
+     elif target_dtype == pl.String:
+         return str(value)
+     else:
+         return value
+
+
+ dump_loader = DumpLoader()
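`DumpLoader` is the read-side counterpart of the dumper: it indexes a directory of `*.pt` dumps whose filenames encode their metadata as `key=value` pairs joined by `___` (something like `name=...___forward_pass_id=...___rank=...___dump_index=....pt`), and serves them back by query. A usage sketch follows; the directory path and the `"hidden_states"` name are placeholders, not values from the diff.

```python
# Hypothetical read-back of a previously dumped tensor.
import os

# Must be set before the module is imported, since the module-level
# `dump_loader = DumpLoader()` reads it at import time.
os.environ["SGLANG_DUMP_LOADER_DIR"] = "/tmp/sglang_dumps"

from sglang.srt.debug_utils.dump_loader import dump_loader

if dump_loader.enable:
    # Matches the row whose metadata equals the query plus the dumper's
    # current forward_pass_id, then torch.load()s that file.
    hidden = dump_loader.load("hidden_states", rank=0)
```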