sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419) hide show
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/check_env.py CHANGED
@@ -47,7 +47,7 @@ PACKAGE_LIST = [
47
47
  "tiktoken",
48
48
  "anthropic",
49
49
  "litellm",
50
- "decord",
50
+ "decord2",
51
51
  ]
52
52
 
53
53
 
@@ -19,6 +19,7 @@ import requests
19
19
 
20
20
  from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
21
21
  from sglang.srt.entrypoints.http_server import launch_server
22
+ from sglang.srt.environ import envs
22
23
  from sglang.srt.managers.io_struct import GenerateReqInput
23
24
  from sglang.srt.managers.tokenizer_manager import TokenizerManager
24
25
  from sglang.srt.server_args import ServerArgs
@@ -28,9 +29,9 @@ from sglang.srt.warmup import warmup
28
29
  multiprocessing.set_start_method("spawn", force=True)
29
30
 
30
31
  # Reduce warning
31
- os.environ["SGL_IN_DEEPGEMM_PRECOMPILE_STAGE"] = "1"
32
+ envs.SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE.set(True)
32
33
  # Force enable deep gemm
33
- os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1"
34
+ envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(True)
34
35
  # Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case
35
36
  os.environ["SGL_CHUNKED_PREFIX_CACHE_THRESHOLD"] = "0"
36
37
 
@@ -141,6 +142,9 @@ def refine_server_args(server_args: ServerArgs, compile_args: CompileArgs):
141
142
  server_args.enable_torch_compile = False
142
143
  print(f"Disable CUDA Graph and Torch Compile to save time...")
143
144
 
145
+ server_args.load_format = "dummy"
146
+ print(f"Set load format to dummy to save time...")
147
+
144
148
  # Set watchdog timeout to compile_args.timeout because compilation will take a long time
145
149
  server_args.watchdog_timeout = compile_args.timeout
146
150
  server_args.warmups = "compile-deep-gemm"
sglang/global_config.py CHANGED
@@ -1,14 +1,11 @@
1
1
  """Global configurations"""
2
2
 
3
- import os
3
+ # FIXME: deprecate this file and move all usage to sglang.srt.environ or sglang.__init__.py
4
4
 
5
5
 
6
6
  class GlobalConfig:
7
7
  """
8
8
  Store some global constants.
9
-
10
- See also python/sglang/srt/managers/schedule_batch.py::global_server_args_dict, which stores
11
- many global runtime arguments as well.
12
9
  """
13
10
 
14
11
  def __init__(self):
@@ -20,27 +17,6 @@ class GlobalConfig:
20
17
  # Default backend of the language
21
18
  self.default_backend = None
22
19
 
23
- # Runtime constants: New generation token ratio estimation
24
- self.default_init_new_token_ratio = float(
25
- os.environ.get("SGLANG_INIT_NEW_TOKEN_RATIO", 0.7)
26
- )
27
- self.default_min_new_token_ratio_factor = float(
28
- os.environ.get("SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR", 0.14)
29
- )
30
- self.default_new_token_ratio_decay_steps = float(
31
- os.environ.get("SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS", 600)
32
- )
33
- self.torch_empty_cache_interval = float(
34
- os.environ.get(
35
- "SGLANG_EMPTY_CACHE_INTERVAL", -1
36
- ) # in seconds. Set if you observe high memory accumulation over a long serving period.
37
- )
38
- # Runtime constants: others
39
- self.retract_decode_steps = 20
40
- self.flashinfer_workspace_size = int(
41
- os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024)
42
- )
43
-
44
20
  # Output tokenization configs
45
21
  self.skip_special_tokens_in_output = True
46
22
  self.spaces_between_special_tokens_in_out = True
sglang/lang/api.py CHANGED
@@ -79,6 +79,7 @@ def gen(
79
79
  n: Optional[int] = None,
80
80
  stop: Optional[Union[str, List[str]]] = None,
81
81
  stop_token_ids: Optional[List[int]] = None,
82
+ stop_regex: Optional[Union[str, List[str]]] = None,
82
83
  temperature: Optional[float] = None,
83
84
  top_p: Optional[float] = None,
84
85
  top_k: Optional[int] = None,
@@ -120,6 +121,7 @@ def gen(
120
121
  n,
121
122
  stop,
122
123
  stop_token_ids,
124
+ stop_regex,
123
125
  temperature,
124
126
  top_p,
125
127
  top_k,
@@ -143,6 +145,7 @@ def gen_int(
143
145
  n: Optional[int] = None,
144
146
  stop: Optional[Union[str, List[str]]] = None,
145
147
  stop_token_ids: Optional[List[int]] = None,
148
+ stop_regex: Optional[Union[str, List[str]]] = None,
146
149
  temperature: Optional[float] = None,
147
150
  top_p: Optional[float] = None,
148
151
  top_k: Optional[int] = None,
@@ -162,6 +165,7 @@ def gen_int(
162
165
  n,
163
166
  stop,
164
167
  stop_token_ids,
168
+ stop_regex,
165
169
  temperature,
166
170
  top_p,
167
171
  top_k,
@@ -184,6 +188,7 @@ def gen_string(
184
188
  n: Optional[int] = None,
185
189
  stop: Optional[Union[str, List[str]]] = None,
186
190
  stop_token_ids: Optional[List[int]] = None,
191
+ stop_regex: Optional[Union[str, List[str]]] = None,
187
192
  temperature: Optional[float] = None,
188
193
  top_p: Optional[float] = None,
189
194
  top_k: Optional[int] = None,
@@ -203,6 +208,7 @@ def gen_string(
203
208
  n,
204
209
  stop,
205
210
  stop_token_ids,
211
+ stop_regex,
206
212
  temperature,
207
213
  top_p,
208
214
  top_k,
@@ -792,6 +792,7 @@ class StreamExecutor:
792
792
  "n",
793
793
  "stop",
794
794
  "stop_token_ids",
795
+ "stop_regex",
795
796
  "temperature",
796
797
  "top_p",
797
798
  "top_k",
sglang/lang/ir.py CHANGED
@@ -21,6 +21,7 @@ class SglSamplingParams:
21
21
  n: int = 1
22
22
  stop: Union[str, List[str]] = ()
23
23
  stop_token_ids: Optional[List[int]] = ()
24
+ stop_regex: Optional[Union[str, List[str]]] = ()
24
25
  temperature: float = 1.0
25
26
  top_p: float = 1.0
26
27
  top_k: int = -1 # -1 means disable
@@ -45,6 +46,7 @@ class SglSamplingParams:
45
46
  self.n,
46
47
  self.stop,
47
48
  self.stop_token_ids,
49
+ self.stop_regex,
48
50
  self.temperature,
49
51
  self.top_p,
50
52
  self.top_k,
@@ -123,6 +125,7 @@ class SglSamplingParams:
123
125
  "n": self.n,
124
126
  "stop": self.stop,
125
127
  "stop_token_ids": self.stop_token_ids,
128
+ "stop_regex": self.stop_regex,
126
129
  "temperature": self.temperature,
127
130
  "top_p": self.top_p,
128
131
  "top_k": self.top_k,
@@ -161,6 +164,7 @@ class SglFunction:
161
164
  n: int = 1,
162
165
  stop: Optional[Union[str, List[str]]] = None,
163
166
  stop_token_ids: Optional[List[int]] = None,
167
+ stop_regex: Optional[Union[str, List[str]]] = None,
164
168
  temperature: float = 1.0,
165
169
  top_p: float = 1.0,
166
170
  top_k: int = -1,
@@ -184,12 +188,15 @@ class SglFunction:
184
188
  stop = []
185
189
  if stop_token_ids is None:
186
190
  stop_token_ids = []
191
+ if stop_regex is None:
192
+ stop_regex = []
187
193
 
188
194
  default_sampling_para = SglSamplingParams(
189
195
  max_new_tokens=max_new_tokens,
190
196
  n=n,
191
197
  stop=stop,
192
198
  stop_token_ids=stop_token_ids,
199
+ stop_regex=stop_regex,
193
200
  temperature=temperature,
194
201
  top_p=top_p,
195
202
  top_k=top_k,
@@ -221,6 +228,7 @@ class SglFunction:
221
228
  n: int = 1,
222
229
  stop: Optional[Union[str, List[str]]] = None,
223
230
  stop_token_ids: Optional[List[int]] = None,
231
+ stop_regex: Optional[Union[str, List[str]]] = None,
224
232
  temperature: float = 1.0,
225
233
  top_p: float = 1.0,
226
234
  top_k: int = -1,
@@ -243,6 +251,8 @@ class SglFunction:
243
251
  stop = []
244
252
  if stop_token_ids is None:
245
253
  stop_token_ids = []
254
+ if stop_regex is None:
255
+ stop_regex = []
246
256
 
247
257
  assert isinstance(batch_kwargs, (list, tuple))
248
258
  if len(batch_kwargs) == 0:
@@ -267,6 +277,7 @@ class SglFunction:
267
277
  n=n,
268
278
  stop=stop,
269
279
  stop_token_ids=stop_token_ids,
280
+ stop_regex=stop_regex,
270
281
  temperature=temperature,
271
282
  top_p=top_p,
272
283
  top_k=top_k,
@@ -451,6 +462,7 @@ class SglGen(SglExpr):
451
462
  n: Optional[int] = None,
452
463
  stop: Optional[Union[str, List[str]]] = None,
453
464
  stop_token_ids: Optional[List[int]] = None,
465
+ stop_regex: Optional[Union[str, List[str]]] = None,
454
466
  temperature: Optional[float] = None,
455
467
  top_p: Optional[float] = None,
456
468
  top_k: Optional[int] = None,
@@ -474,6 +486,7 @@ class SglGen(SglExpr):
474
486
  min_new_tokens=min_new_tokens,
475
487
  n=n,
476
488
  stop=stop,
489
+ stop_regex=stop_regex,
477
490
  stop_token_ids=stop_token_ids,
478
491
  temperature=temperature,
479
492
  top_p=top_p,
sglang/launch_server.py CHANGED
@@ -1,30 +1,25 @@
1
1
  """Launch the inference server."""
2
2
 
3
+ import asyncio
3
4
  import os
4
5
  import sys
5
6
 
6
- from sglang.srt.entrypoints.http_server import launch_server
7
7
  from sglang.srt.server_args import prepare_server_args
8
8
  from sglang.srt.utils import kill_process_tree
9
9
 
10
- MOVE_ENVS_WARN = """
11
- ########################################################################
12
- # For contributors and developers: #
13
- # Please move environment variable definitions to sglang.srt.environ #
14
- # using the following pattern: #
15
- # SGLANG_XXX = EnvBool(False) #
16
- # #
17
- ########################################################################
18
- """
19
-
20
10
  if __name__ == "__main__":
21
11
  server_args = prepare_server_args(sys.argv[1:])
22
12
 
23
- from sglang.srt.server_args import print_deprecated_warning
13
+ try:
14
+ if server_args.grpc_mode:
15
+ # Handle gRPC server
16
+ from sglang.srt.entrypoints.grpc_server import serve_grpc
24
17
 
25
- print_deprecated_warning(MOVE_ENVS_WARN)
18
+ asyncio.run(serve_grpc(server_args))
19
+ else:
20
+ # Handle HTTP server
21
+ from sglang.srt.entrypoints.http_server import launch_server
26
22
 
27
- try:
28
- launch_server(server_args)
23
+ launch_server(server_args)
29
24
  finally:
30
25
  kill_process_tree(os.getpid(), include_parent=False)
sglang/profiler.py CHANGED
@@ -25,6 +25,7 @@ def _run_profile(
25
25
  output_dir: Optional[str] = None,
26
26
  profile_name: Optional[str] = None,
27
27
  profile_by_stage: bool = False,
28
+ merge_profiles: bool = False,
28
29
  ) -> str:
29
30
  if output_dir is None:
30
31
  output_dir = PROFILER_DIR
@@ -60,6 +61,7 @@ def _run_profile(
60
61
  "num_steps": str(num_steps),
61
62
  "activities": activities,
62
63
  "profile_by_stage": profile_by_stage,
64
+ "merge_profiles": merge_profiles,
63
65
  }
64
66
 
65
67
  response = requests.post(url=url + "/start_profile", json=json_data)
@@ -76,10 +78,17 @@ def run_profile(
76
78
  output_dir: Optional[str] = None,
77
79
  profile_name: Optional[str] = None,
78
80
  profile_by_stage: bool = False,
81
+ merge_profiles: bool = False,
79
82
  ):
80
83
  # step based profile will self terminate on num_steps constraints
81
84
  link = _run_profile(
82
- url, num_steps, activities, output_dir, profile_name, profile_by_stage
85
+ url,
86
+ num_steps,
87
+ activities,
88
+ output_dir,
89
+ profile_name,
90
+ profile_by_stage,
91
+ merge_profiles,
83
92
  )
84
93
  return link
85
94
 
@@ -145,6 +154,13 @@ if __name__ == "__main__":
145
154
  default=False,
146
155
  help="Whether to use rpd profiler (https://github.com/ROCm/rocmProfileData)",
147
156
  )
157
+ parser.add_argument(
158
+ "--merge-profiles",
159
+ action=argparse.BooleanOptionalAction,
160
+ type=bool,
161
+ default=False,
162
+ help="Whether to merge profiles from all ranks into a single trace file",
163
+ )
148
164
 
149
165
  args = parser.parse_args()
150
166
  activities = []
@@ -163,4 +179,5 @@ if __name__ == "__main__":
163
179
  args.output_dir,
164
180
  args.profile_name,
165
181
  args.profile_by_stage,
182
+ args.merge_profiles,
166
183
  )
sglang/srt/_custom_ops.py CHANGED
@@ -15,7 +15,7 @@ if not is_hpu():
15
15
  # ROCm does not use vllm custom allreduce
16
16
  if use_vllm_custom_allreduce and not is_hip():
17
17
  try:
18
- import vllm._C
18
+ import vllm._C # noqa: F401
19
19
  except ImportError as e:
20
20
  logger.warning("Failed to import from vllm._C with %r", e)
21
21
  else:
@@ -9,6 +9,22 @@ import torch
9
9
  import triton
10
10
  import triton.language as tl
11
11
 
12
+ from sglang.srt.layers.deep_gemm_wrapper.configurer import ENABLE_JIT_DEEPGEMM
13
+ from sglang.srt.utils.common import calc_diff, get_bool_env_var
14
+
15
+ if ENABLE_JIT_DEEPGEMM:
16
+ import deep_gemm
17
+
18
+ _ENABLE_MM_DEEPGEMM = get_bool_env_var(
19
+ "SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_DEEPGEMM", "1"
20
+ )
21
+ _ENABLE_MM_COMPARISON_TEST = get_bool_env_var(
22
+ "SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_COMPARISON_TEST"
23
+ )
24
+
25
+ if not _ENABLE_MM_DEEPGEMM:
26
+ print("Disable DeepGEMM in batch invariant ops. Performance may be suboptimal.")
27
+
12
28
  __all__ = [
13
29
  "set_batch_invariant_mode",
14
30
  "is_batch_invariant_mode_enabled",
@@ -77,8 +93,6 @@ def matmul_kernel_persistent(
77
93
  k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
78
94
  num_tiles = num_pid_m * num_pid_n
79
95
 
80
- tile_id_c = start_pid - NUM_SMS
81
-
82
96
  offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
83
97
  num_pid_in_group = GROUP_SIZE_M * num_pid_n
84
98
 
@@ -120,10 +134,6 @@ def matmul_kernel_persistent(
120
134
  )
121
135
  accumulator = tl.dot(a, b, accumulator)
122
136
 
123
- tile_id_c += NUM_SMS
124
- pid_m, pid_n = _compute_pid(
125
- tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
126
- )
127
137
  offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
128
138
  offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
129
139
  if C_LARGE:
@@ -137,12 +147,16 @@ def matmul_kernel_persistent(
137
147
  accumulator += bias
138
148
  if c_ptr.dtype.element_ty == tl.float8e4nv:
139
149
  c = accumulator.to(tl.float8e4nv)
150
+ elif c_ptr.dtype.element_ty == tl.bfloat16:
151
+ c = accumulator.to(tl.bfloat16)
152
+ elif c_ptr.dtype.element_ty == tl.float32:
153
+ c = accumulator.to(tl.float32)
140
154
  else:
141
155
  c = accumulator.to(tl.float16)
142
156
  tl.store(c_ptrs, c, mask=c_mask)
143
157
 
144
158
 
145
- def matmul_persistent(
159
+ def _matmul_persistent_triton(
146
160
  a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
147
161
  ):
148
162
  # Check constraints.
@@ -219,6 +233,54 @@ def matmul_persistent(
219
233
  return c
220
234
 
221
235
 
236
+ def _matmul_persistent_deepgemm(
237
+ a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
238
+ ):
239
+ M, K = a.shape
240
+ K, N = b.shape
241
+ dtype = a.dtype
242
+ out = torch.empty((M, N), device=a.device, dtype=dtype)
243
+
244
+ deep_gemm.bf16_gemm_nn(a, b, out)
245
+
246
+ # TODO can this be put in DeepGEMM's `c`?
247
+ if bias is not None:
248
+ out += bias
249
+
250
+ return out
251
+
252
+
253
+ def matmul_persistent(
254
+ a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
255
+ ):
256
+ if (
257
+ _ENABLE_MM_DEEPGEMM
258
+ and ENABLE_JIT_DEEPGEMM
259
+ and (a.dtype == torch.bfloat16)
260
+ and (b.dtype == torch.bfloat16)
261
+ and a.is_contiguous()
262
+ and b.transpose(0, 1).is_contiguous()
263
+ ):
264
+ if _ENABLE_MM_COMPARISON_TEST:
265
+ out_triton = _matmul_persistent_triton(a=a, b=b, bias=bias)
266
+ out_deepgemm = _matmul_persistent_deepgemm(a=a, b=b, bias=bias)
267
+ diff = calc_diff(out_triton, out_deepgemm)
268
+ assert diff < 0.0001, f"{diff=} {out_triton=} {out_deepgemm=}"
269
+ # can be enabled for debugging
270
+ # print(
271
+ # f"{diff=} "
272
+ # f"{(out_triton - out_deepgemm).abs().mean()=} "
273
+ # f"{(out_triton - out_deepgemm).abs().sum()=} "
274
+ # f"{torch.sum(out_triton != out_deepgemm)=} "
275
+ # )
276
+ # print(f"{a=} {b=} {bias=} {out_triton=} {out_deepgemm=}")
277
+ return out_deepgemm
278
+
279
+ return _matmul_persistent_deepgemm(a=a, b=b, bias=bias)
280
+
281
+ return _matmul_persistent_triton(a=a, b=b, bias=bias)
282
+
283
+
222
284
  @triton.jit
223
285
  def _log_softmax_kernel(
224
286
  input_ptr,
@@ -497,16 +559,39 @@ def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None =
497
559
  return torch.sum(input, dim=dim, keepdim=keepdim, dtype=torch.float32) / n_elems
498
560
 
499
561
 
562
+ def bmm_batch_invariant(a, b, *, out=None):
563
+ # Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N)
564
+ # Process each batch separately with our persistent kernel
565
+ if a.ndim == 3 and b.ndim == 3:
566
+ results = []
567
+ for i in range(a.shape[0]):
568
+ results.append(matmul_persistent(a[i], b[i]))
569
+ result = torch.stack(results, dim=0)
570
+
571
+ if out is not None:
572
+ out.copy_(result)
573
+ return out
574
+ return result
575
+ else:
576
+ raise ValueError(
577
+ f"bmm_batch_invariant expects 3D tensors, "
578
+ f"got shapes {a.shape} and {b.shape}"
579
+ )
580
+
581
+
500
582
  _batch_invariant_MODE = False
501
583
  _batch_invariant_LIB = None
584
+ _original_torch_bmm = None
502
585
 
503
586
 
504
587
  def is_batch_invariant_mode_enabled():
505
588
  return _batch_invariant_MODE
506
589
 
507
590
 
508
- def enable_batch_invariant_mode():
509
- global _batch_invariant_MODE, _batch_invariant_LIB
591
+ def enable_batch_invariant_mode(
592
+ enable_bmm: bool = True,
593
+ ):
594
+ global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
510
595
  if _batch_invariant_MODE:
511
596
  return
512
597
 
@@ -519,11 +604,21 @@ def enable_batch_invariant_mode():
519
604
  )
520
605
  _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
521
606
 
607
+ if enable_bmm:
608
+ _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
609
+
610
+ # Also monkeypatch torch.bmm directly as a fallback
611
+ _original_torch_bmm = torch.bmm
612
+ torch.bmm = bmm_batch_invariant
613
+
522
614
 
523
615
  def disable_batch_invariant_mode():
524
- global _batch_invariant_MODE, _batch_invariant_LIB
616
+ global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
525
617
  if _batch_invariant_LIB is not None:
526
618
  _batch_invariant_LIB._destroy()
619
+ if _original_torch_bmm is not None:
620
+ torch.bmm = _original_torch_bmm
621
+ _original_torch_bmm = None
527
622
  _batch_invariant_MODE = False
528
623
  _batch_invariant_LIB = None
529
624
 
@@ -0,0 +1,142 @@
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+ """
15
+ Checkpoint-engine integration for SGLang.
16
+ This module provides weight update functionality via IPC for checkpoint-engine compatibility.
17
+ """
18
+ import logging
19
+ from typing import Callable, Dict, Optional
20
+
21
+ import torch
22
+ import zmq
23
+
24
+ try:
25
+ from checkpoint_engine.worker import update_weights_from_ipc
26
+ except ImportError:
27
+ raise ImportError(
28
+ "checkpoint-engine is not installed. "
29
+ "Please install it with: pip install sglang[checkpoint-engine]"
30
+ )
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ class SGLangCheckpointEngineWorkerExtension:
36
+ """
37
+ Worker extension for SGLang to support checkpoint-engine IPC weight updates.
38
+ This class provides the interface needed for checkpoint-engine integration.
39
+ """
40
+
41
+ def __init__(self):
42
+ self._zmq_ctx: Optional[zmq.Context] = None
43
+
44
+ def get_device_uuid(self) -> str:
45
+ """Get the UUID of current device."""
46
+ # We need to implement this to get the device UUID
47
+ # This will be overridden when integrated into SGLang's worker
48
+ raise NotImplementedError(
49
+ "This method should be overridden by SGLang integration"
50
+ )
51
+
52
+ def get_device_id(self) -> int:
53
+ """Get the device ID."""
54
+ raise NotImplementedError(
55
+ "This method should be overridden by SGLang integration"
56
+ )
57
+
58
+ def get_model_loader(self) -> Callable:
59
+ """Get the model weight loader function."""
60
+ raise NotImplementedError(
61
+ "This method should be overridden by SGLang integration"
62
+ )
63
+
64
+ def get_post_hook(self) -> Optional[Callable]:
65
+ """Get the post-processing hook after weight loading."""
66
+ return None
67
+
68
+ def update_weights_from_ipc(self, zmq_handles: Dict[str, str]):
69
+ """
70
+ Update weights from IPC communication.
71
+ Args:
72
+ zmq_handles: Dict mapping device UUID to ZMQ socket path
73
+ """
74
+ if self._zmq_ctx is None:
75
+ self._zmq_ctx = zmq.Context()
76
+ device_uuid = self.get_device_uuid()
77
+ device_id = self.get_device_id()
78
+ if device_uuid not in zmq_handles:
79
+ raise ValueError(
80
+ f"Device UUID {device_uuid} not found in zmq_handles: {list(zmq_handles.keys())}"
81
+ )
82
+ update_weights_from_ipc(
83
+ self._zmq_ctx,
84
+ zmq_handles[device_uuid],
85
+ device_id=device_id,
86
+ run=self.get_model_loader(),
87
+ post_hook=self.get_post_hook(),
88
+ )
89
+
90
+
91
+ class SGLangCheckpointEngineWorkerExtensionImpl(SGLangCheckpointEngineWorkerExtension):
92
+ """
93
+ Implementation of SGLangCheckpointEngineWorkerExtension that integrates with SGLang's model runner.
94
+ This class provides the concrete implementation for checkpoint-engine IPC weight updates.
95
+ """
96
+
97
+ def __init__(self, model_runner):
98
+ super().__init__()
99
+ self.model_runner = model_runner
100
+
101
+ def get_device_uuid(self) -> str:
102
+ """Get the UUID of current device."""
103
+ # Get device UUID for current device
104
+ device_id = torch.cuda.current_device()
105
+ try:
106
+ return f"GPU-{torch.cuda.get_device_properties(device_id).uuid!s}"
107
+ except AssertionError as e:
108
+ raise ValueError(f"Failed to get GPU UUID for device {device_id}") from e
109
+
110
+ def get_device_id(self) -> int:
111
+ """Get the device ID."""
112
+ return torch.cuda.current_device()
113
+
114
+ def get_model_loader(self) -> Callable:
115
+ """Get the model weight loader function."""
116
+ return self.model_runner.model.load_weights
117
+
118
+ def get_post_hook(self) -> Optional[Callable]:
119
+ """Get the post-processing hook after weight loading."""
120
+
121
+ def post_hook():
122
+ # Perform post-processing after weight loading similar to DefaultModelLoader
123
+ try:
124
+ from sglang.srt.model_loader.loader import device_loading_context
125
+
126
+ # Process quantization methods after loading weights
127
+ for _, module in self.model_runner.model.named_modules():
128
+ quant_method = getattr(module, "quant_method", None)
129
+ if quant_method is not None:
130
+ # Move parameters to device if needed for quantization processing
131
+ target_device = torch.device(
132
+ "cuda", torch.cuda.current_device()
133
+ )
134
+ with device_loading_context(module, target_device):
135
+ quant_method.process_weights_after_loading(module)
136
+ # Call model-specific post-loading hook if available
137
+ if hasattr(self.model_runner.model, "post_load_weights"):
138
+ self.model_runner.model.post_load_weights()
139
+ except Exception as e:
140
+ logger.warning(f"Post-hook processing failed: {e}")
141
+
142
+ return post_hook