sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/communicator.py

@@ -40,8 +40,9 @@ from sglang.srt.layers.moe import (
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.server_args import get_global_server_args
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     get_bool_env_var,
     is_cuda,
@@ -50,6 +51,7 @@ from sglang.srt.utils import (
     is_hip,
     is_sm90_supported,
     is_sm100_supported,
+    prepare_weight_cache,
 )
 
 _is_flashinfer_available = is_flashinfer_available()
@@ -167,7 +169,7 @@ class LayerScatterModes:
 
 
 def enable_moe_dense_fully_dp():
-    return global_server_args_dict["moe_dense_tp_size"] == 1
+    return get_global_server_args().moe_dense_tp_size == 1
 
 
 class LayerCommunicator:
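Note: the hunks in this file all follow one migration: dictionary lookups on the removed `global_server_args_dict` become attribute access on the object returned by `get_global_server_args()`. A minimal sketch of the new pattern, using only the accessor and field names visible in these hunks:

```python
from sglang.srt.server_args import get_global_server_args

args = get_global_server_args()
# Old: global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
# New: typed attribute access; the default now lives on ServerArgs itself.
fusion_enabled = args.enable_flashinfer_allreduce_fusion
moe_dense_fully_dp = args.moe_dense_tp_size == 1
```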
@@ -210,6 +212,10 @@ class LayerCommunicator:
             )
         )
 
+        self._speculative_algo = SpeculativeAlgorithm.from_string(
+            get_global_server_args().speculative_algorithm
+        )
+
     def prepare_attn(
         self,
         hidden_states: torch.Tensor,
@@ -275,7 +281,11 @@
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
+        cache=None,
     ):
+        if cache is not None:
+            self._context.cache = cache
+
         return self._communicate_with_all_reduce_and_layer_norm_fn(
             hidden_states=hidden_states,
             residual=residual,
@@ -309,11 +319,10 @@
     def should_fuse_mlp_allreduce_with_next_layer(
         self, forward_batch: ForwardBatch
     ) -> bool:
-        speculative_algo = global_server_args_dict.get("speculative_algorithm", None)
         if (
             is_dp_attention_enabled()
-            and speculative_algo is not None
-            and speculative_algo.is_eagle()
+            and self._speculative_algo is not None
+            and self._speculative_algo.is_eagle()
         ):
             return False
 
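Note: parsing of the speculative algorithm moves out of every `should_fuse_mlp_allreduce_with_next_layer` call and into `__init__`, so the hot path only reads a cached value. A sketch of the caching idea, hedged to the names shown above (`SpeculativeAlgorithm.from_string` is assumed to tolerate `None`):

```python
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

class _CommunicatorSketch:
    def __init__(self):
        # Parse once at construction instead of on every forward pass.
        self._speculative_algo = SpeculativeAlgorithm.from_string(
            get_global_server_args().speculative_algorithm
        )

    def _eagle_blocks_fusion(self) -> bool:
        algo = self._speculative_algo
        return algo is not None and algo.is_eagle()
```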
@@ -328,7 +337,7 @@
         static_conditions_met = (
             (not self.is_last_layer)
             and (self._context.tp_size > 1)
-            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
+            and get_global_server_args().enable_flashinfer_allreduce_fusion
             and _is_flashinfer_available
         )
 
@@ -349,6 +358,7 @@ class CommunicateContext:
     attn_tp_size: int
     attn_dp_size: int
     tp_size: int
+    cache = None
 
     def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
         return self.process_group_sizes[a] == self.process_group_sizes[b]
@@ -525,7 +535,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
             (_is_sm100_supported or _is_sm90_supported)
             and _is_flashinfer_available
             and hasattr(layernorm, "forward_with_allreduce_fusion")
-            and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
+            and get_global_server_args().enable_flashinfer_allreduce_fusion
             and hidden_states.shape[0] <= 4096
         ):
             hidden_states, residual = layernorm.forward_with_allreduce_fusion(
@@ -533,6 +543,8 @@
             )
         else:
             hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+            if context.cache is not None:
+                _ = prepare_weight_cache(hidden_states, context.cache)
             hidden_states, residual = layernorm(hidden_states, residual)
         return hidden_states, residual
 
sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py

@@ -7,11 +7,10 @@ from typing import Dict, List, Tuple
 import torch
 from tqdm import tqdm
 
-from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
-    ENABLE_JIT_DEEPGEMM,
-)
+from sglang.srt.environ import envs
+from sglang.srt.layers.deep_gemm_wrapper.configurer import ENABLE_JIT_DEEPGEMM
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import ceil_div, get_bool_env_var, get_int_env_var
+from sglang.srt.utils import ceil_div, get_bool_env_var
 
 logger = logging.getLogger(__name__)
 
@@ -20,12 +19,9 @@ if ENABLE_JIT_DEEPGEMM:
 
 
 _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
-_ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
-    "SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
-)
+_ENABLE_JIT_DEEPGEMM_PRECOMPILE = envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.get()
 _DO_COMPILE_ALL = True
 _IS_FIRST_RANK_ON_NODE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
-_COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
 _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")
 
 # Force redirect deep_gemm cache_dir
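Note: the ad-hoc `get_bool_env_var("SGL_JIT_DEEPGEMM_PRECOMPILE", "true")` read is replaced by the typed registry in the new `sglang/srt/environ.py` (+323 lines in the file list), accessed as `envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.get()`. A rough stand-in for what such a typed accessor does, purely illustrative and not sglang's actual implementation:

```python
import os

class EnvBool:
    """Illustrative typed env-var handle in the spirit of sglang.srt.environ."""

    def __init__(self, name: str, default: bool):
        self.name = name
        self.default = default

    def get(self) -> bool:
        # Fall back to the default when the variable is unset.
        raw = os.environ.get(self.name)
        if raw is None:
            return self.default
        return raw.strip().lower() in ("1", "true", "yes", "on")

# Hypothetical registration mirroring the name used in the hunk above.
SGLANG_JIT_DEEPGEMM_PRECOMPILE = EnvBool("SGLANG_JIT_DEEPGEMM_PRECOMPILE", True)
```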
sglang/srt/layers/deep_gemm_wrapper/configurer.py (new file)

@@ -0,0 +1,25 @@
+import logging
+
+from sglang.srt.environ import envs
+from sglang.srt.utils import get_device_sm, is_blackwell
+
+logger = logging.getLogger(__name__)
+
+
+def _compute_enable_deep_gemm():
+    sm_version = get_device_sm()
+    if sm_version < 90:
+        return False
+
+    try:
+        import deep_gemm  # noqa: F401
+    except ImportError:
+        return False
+
+    return envs.SGLANG_ENABLE_JIT_DEEPGEMM.get()
+
+
+ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
+
+DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and is_blackwell()
+DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL
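Note: the flags computed by this new configurer gate every JIT DeepGEMM code path; the entrypoint hunk below imports `deep_gemm` only when `ENABLE_JIT_DEEPGEMM` is true. A minimal consumer, using only names defined in this file:

```python
from sglang.srt.layers.deep_gemm_wrapper.configurer import (
    DEEPGEMM_BLACKWELL,
    ENABLE_JIT_DEEPGEMM,
)

if ENABLE_JIT_DEEPGEMM:
    # Safe: the configurer already verified SM >= 90 and that deep_gemm imports.
    import deep_gemm

if DEEPGEMM_BLACKWELL:
    print("Using Blackwell DeepGEMM kernels (UE8M0 scaling enabled)")
```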
sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py

@@ -4,8 +4,8 @@ from typing import Tuple
 
 import torch
 
-from sglang.srt.layers.quantization.deep_gemm_wrapper import compile_utils
-from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
+from sglang.srt.layers.deep_gemm_wrapper import compile_utils
+from sglang.srt.layers.deep_gemm_wrapper.configurer import (  # noqa: F401
     DEEPGEMM_BLACKWELL,
     DEEPGEMM_SCALE_UE8M0,
     ENABLE_JIT_DEEPGEMM,
@@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
 
 if ENABLE_JIT_DEEPGEMM:
     import deep_gemm
-    from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
+    from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor  # noqa: F401
 
 _SANITY_CHECK = get_bool_env_var("SGLANG_DEEPGEMM_SANITY_CHECK")
 
sglang/srt/layers/dp_attention.py

@@ -17,6 +17,7 @@ from sglang.srt.distributed import (
     get_tp_group,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.utils import get_bool_env_var, is_hip
 
 if TYPE_CHECKING:
     from sglang.srt.configs.model_config import ModelConfig
@@ -36,6 +37,9 @@ _LOCAL_ATTN_DP_SIZE: Optional[int] = None
 _LOCAL_ATTN_DP_RANK: Optional[int] = None
 _ENABLE_DP_ATTENTION_FLAG: bool = False
 
+_is_hip = is_hip()
+_USE_ROCM700A_WA = _is_hip and get_bool_env_var("SGLANG_USE_ROCM700A")
+
 
 class DpPaddingMode(IntEnum):
 
@@ -67,7 +71,12 @@
 
     @classmethod
     def get_default_mode_in_cuda_graph(cls) -> DpPaddingMode:
-        return cls.MAX_LEN
+        # TODO(kkhuang-amd): noqa, temporary work-around for rocm 7.0.0 alpha
+        # it can be safely removed later, once RCCL fixed
+        if _USE_ROCM700A_WA:
+            return cls.SUM_LEN
+        else:
+            return cls.MAX_LEN
 
 
 class _DpGatheredBufferWrapper:
@@ -78,6 +87,7 @@ class _DpGatheredBufferWrapper:
     _global_dp_buffer_len: int
     _local_dp_buffer_len: int
    _global_num_tokens: Optional[List[int]]
+    _is_extend_in_batch: bool
 
     @classmethod
     def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
@@ -136,6 +146,14 @@
     def get_dp_device(cls) -> torch.device:
         return cls._device
 
+    @classmethod
+    def set_is_extend_in_batch(cls, is_extend_in_batch: bool):
+        cls._is_extend_in_batch = is_extend_in_batch
+
+    @classmethod
+    def get_is_extend_in_batch(cls) -> bool:
+        return cls._is_extend_in_batch
+
 
 def set_dp_buffer_len(
     global_dp_buffer_len: int,
@@ -179,6 +197,14 @@ def get_dp_device() -> torch.device:
     return _DpGatheredBufferWrapper.get_dp_device()
 
 
+def set_is_extend_in_batch(is_extend_in_batch: bool):
+    _DpGatheredBufferWrapper.set_is_extend_in_batch(is_extend_in_batch)
+
+
+def get_is_extend_in_batch() -> bool:
+    return _DpGatheredBufferWrapper.get_is_extend_in_batch()
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
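Note: `_DpGatheredBufferWrapper` keeps process-wide DP metadata in class attributes and exposes thin module-level wrappers as the public API; the new `_is_extend_in_batch` flag follows the same shape. A condensed sketch of that pattern (illustrative, not the sglang source):

```python
class _StateHolder:
    """Class attributes act as process-wide storage; no instances are created."""

    _is_extend_in_batch: bool = False

    @classmethod
    def set_is_extend_in_batch(cls, value: bool) -> None:
        cls._is_extend_in_batch = value

    @classmethod
    def get_is_extend_in_batch(cls) -> bool:
        return cls._is_extend_in_batch


# Module-level functions form the public API, hiding the holder class.
def set_is_extend_in_batch(value: bool) -> None:
    _StateHolder.set_is_extend_in_batch(value)


def get_is_extend_in_batch() -> bool:
    return _StateHolder.get_is_extend_in_batch()
```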
@@ -254,6 +280,7 @@ def initialize_dp_attention(
         use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
         use_pymscclpp=False,
         use_custom_allreduce=False,
+        use_torch_symm_mem=False,
         use_hpu_communicator=False,
         use_xpu_communicator=False,
         use_npu_communicator=False,
sglang/srt/layers/elementwise.py

@@ -187,7 +187,9 @@ fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(
 
 def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
     assert len(x.shape) == 2
-    assert x.shape == residual.shape and x.dtype == residual.dtype
+    assert (
+        x.shape == residual.shape and x.dtype == residual.dtype
+    ), f"{x.shape=} {residual.shape=} {x.dtype=} {residual.dtype=}"
     output, mid = torch.empty_like(x), torch.empty_like(x)
     bs, hidden_dim = x.shape
     if autotune:
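Note: the strengthened assert uses the Python 3.8+ `f"{expr=}"` self-documenting format, so a failure now reports the offending shapes and dtypes instead of a bare AssertionError. For example:

```python
import torch

x = torch.zeros(2, 4)
residual = torch.zeros(2, 4)
assert (
    x.shape == residual.shape and x.dtype == residual.dtype
), f"{x.shape=} {residual.shape=} {x.dtype=} {residual.dtype=}"
# On a mismatch the message reads, e.g.:
#   x.shape=torch.Size([2, 4]) residual.shape=torch.Size([2, 8]) ...
```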
sglang/srt/layers/layernorm.py

@@ -42,13 +42,16 @@ _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _is_xpu = is_xpu()
 
-if _is_cuda:
-    if _is_flashinfer_available:
-        from flashinfer.norm import fused_add_rmsnorm
-    else:
-        from sgl_kernel import fused_add_rmsnorm
-    from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
-
+if _is_cuda or _is_xpu:
+    # if _is_flashinfer_available:
+    #     from flashinfer.norm import fused_add_rmsnorm
+    # else:
+    from sgl_kernel import (
+        fused_add_rmsnorm,
+        gemma_fused_add_rmsnorm,
+        gemma_rmsnorm,
+        rmsnorm,
+    )
 if _use_aiter:
     from aiter import rmsnorm2d_fwd as rms_norm
     from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
@@ -80,6 +83,8 @@ class RMSNorm(CustomOp):
         )
         if _use_aiter:
             self._forward_method = self.forward_aiter
+        if get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_INFERENCE"):
+            self._forward_method = self.forward_native
 
     def forward_cuda(
         self,
@@ -209,6 +214,19 @@ class RMSNorm(CustomOp):
         else:
             return self.forward_native(x, residual)
 
+    def forward_xpu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if self.variance_size_override is not None:
+            return self.forward_native(x, residual)
+        if residual is not None:
+            fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
+            return x, residual
+        out = rmsnorm(x, self.weight.data, self.variance_epsilon)
+        return out
+
     def forward_with_allreduce_fusion(
         self,
         x: torch.Tensor,
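Note: `RMSNorm` gains a dedicated `forward_xpu` alongside `forward_cuda`/`forward_npu`; these CustomOp-style layers select `self._forward_method` once at construction (see the aiter and deterministic-inference overrides above). A stripped-down sketch of that dispatch idea, not the actual CustomOp class:

```python
import torch

class DispatchingRMSNorm:
    def __init__(self, weight: torch.Tensor, eps: float = 1e-6):
        self.weight, self.eps = weight, eps
        # Pick one implementation up front; the forward path never re-checks.
        if torch.cuda.is_available():
            self._forward_method = self.forward_cuda
        else:
            self._forward_method = self.forward_native

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return self._forward_method(x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        var = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(var + self.eps) * self.weight

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # A real backend would call a fused kernel here (e.g. sgl_kernel.rmsnorm).
        return self.forward_native(x)
```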
@@ -256,6 +274,19 @@ class GemmaRMSNorm(CustomOp):
         if _is_hip:
             self._forward_method = self.forward_native
 
+    def _forward_impl(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            gemma_fused_add_rmsnorm(
+                x, residual, self.weight.data, self.variance_epsilon
+            )
+            return x, residual
+        out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
+        return out
+
     def forward_native(
         self,
         x: torch.Tensor,
@@ -278,13 +309,7 @@
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        if residual is not None:
-            gemma_fused_add_rmsnorm(
-                x, residual, self.weight.data, self.variance_epsilon
-            )
-            return x, residual
-        out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
-        return out
+        return self._forward_impl(x, residual)
 
     def forward_npu(
         self,
@@ -298,6 +323,13 @@
         x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
         return x if residual is None else (x, residual)
 
+    def forward_xpu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        return self._forward_impl(x, residual)
+
 
 class Gemma3RMSNorm(CustomOp):
     def __init__(self, dim: int, eps: float = 1e-6):
@@ -333,4 +365,4 @@ if not (
     logger.info(
         "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
     )
-    from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
+    from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm  # noqa: F401
sglang/srt/layers/linear.py

@@ -31,7 +31,8 @@ from sglang.srt.layers.parameter import (
     _ColumnvLLMParameter,
 )
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
-from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
+from sglang.srt.layers.utils import pad_or_narrow_weight
+from sglang.srt.utils import get_bool_env_var, is_cpu, is_hip, is_npu, set_weight_attrs
 
 if TYPE_CHECKING:
     from sglang.srt.layers.quantization.base_config import (
@@ -39,12 +40,18 @@ if TYPE_CHECKING:
         QuantizeMethodBase,
     )
 
+_is_hip = is_hip()
+_disable_hip_linear_quant = _is_hip and get_bool_env_var(
+    "SGLANG_ROCM_DISABLE_LINEARQUANT"
+)
+
 logger = logging.getLogger(__name__)
 
 WEIGHT_LOADER_V2_SUPPORTED = [
     "CompressedTensorsLinearMethod",
     "AWQMarlinLinearMethod",
     "AWQLinearMethod",
+    "AWQLinearAscendMethod",
     "GPTQMarlinLinearMethod",
     "Fp8LinearMethod",
     "BlockInt8LinearMethod",
@@ -625,9 +632,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 # bitsandbytes loads the weights of the specific portion
                 # no need to narrow here
                 if not use_bitsandbytes_4bit and not self.use_presharded_weights:
-                    loaded_weight = loaded_weight.narrow(
-                        output_dim, start_idx, shard_size
-                    )
+                    # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
+                    end_idx = start_idx + shard_size
+                    if end_idx > loaded_weight.shape[output_dim]:
+                        loaded_weight = pad_or_narrow_weight(
+                            loaded_weight, output_dim, start_idx, shard_size
+                        )
+                    else:
+                        loaded_weight = loaded_weight.narrow(
+                            output_dim, start_idx, shard_size
+                        )
 
             # Special case for AQLM codebooks.
             elif is_metadata:
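Note: both weight-loader hunks in this file fall back to `pad_or_narrow_weight` (added to `sglang/srt/layers/utils.py`, +23 -1 in the file list) when a tensor-parallel shard would read past the end of the checkpoint tensor, e.g. Qwen2.5-VL MLP dims that are not multiples of the shard count. The helper's exact behavior is not shown in this diff; a plausible zero-padding sketch:

```python
import torch
import torch.nn.functional as F

def pad_or_narrow_weight_sketch(
    w: torch.Tensor, dim: int, start_idx: int, shard_size: int
) -> torch.Tensor:
    """Take w[start_idx : start_idx + shard_size] along `dim`, zero-padding
    the tail when the slice runs past the end of the tensor (assumed)."""
    available = max(w.shape[dim] - start_idx, 0)
    if available >= shard_size:
        return w.narrow(dim, start_idx, shard_size)
    out = w.narrow(dim, start_idx, available)
    # F.pad takes (left, right) pairs starting from the last dimension.
    pad = [0, 0] * w.dim()
    pad[2 * (w.dim() - 1 - dim) + 1] = shard_size - available
    return F.pad(out, pad)
```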
@@ -816,6 +830,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             self.num_kv_heads * self.head_size * tp_size,  # v_proj
         ]
         self.use_presharded_weights = load_presharded_attn
+        quant_config = None if _disable_hip_linear_quant else quant_config
 
         super().__init__(
             input_size=input_size,
@@ -1217,6 +1232,7 @@ class RowParallelLinear(LinearBase):
         tp_size: Optional[int] = None,
         use_presharded_weights: bool = False,
     ):
+        quant_config = None if _disable_hip_linear_quant else quant_config
         super().__init__(
             input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
         )
@@ -1302,7 +1318,16 @@
                     shard_size,
                 )
             else:
-                loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
+                # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
+                end_idx = start_idx + shard_size
+                if end_idx > loaded_weight.shape[input_dim]:
+                    loaded_weight = pad_or_narrow_weight(
+                        loaded_weight, input_dim, start_idx, shard_size
+                    )
+                else:
+                    loaded_weight = loaded_weight.narrow(
+                        input_dim, start_idx, shard_size
+                    )
 
         # Special case for loading scales off disk, which often do not
         # have a shape (such as in the case of AutoFP8).