sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395) hide show
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/lang/interpreter.py +1 -1
  7. sglang/launch_server.py +14 -0
  8. sglang/profiler.py +2 -2
  9. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  10. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  11. sglang/srt/configs/__init__.py +8 -0
  12. sglang/srt/configs/device_config.py +3 -1
  13. sglang/srt/configs/dots_ocr.py +64 -0
  14. sglang/srt/configs/dots_vlm.py +139 -0
  15. sglang/srt/configs/falcon_h1.py +360 -0
  16. sglang/srt/configs/internvl.py +6 -0
  17. sglang/srt/configs/load_config.py +9 -0
  18. sglang/srt/configs/model_config.py +181 -82
  19. sglang/srt/configs/qwen3_next.py +326 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +71 -19
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +326 -53
  44. sglang/srt/disaggregation/prefill.py +36 -17
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +192 -113
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  52. sglang/srt/entrypoints/grpc_server.py +810 -0
  53. sglang/srt/entrypoints/http_server.py +132 -57
  54. sglang/srt/entrypoints/openai/protocol.py +115 -7
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +207 -58
  57. sglang/srt/entrypoints/openai/serving_completions.py +17 -4
  58. sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +49 -4
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/environ.py +285 -0
  63. sglang/srt/eplb/eplb_manager.py +2 -2
  64. sglang/srt/eplb/expert_distribution.py +26 -13
  65. sglang/srt/eplb/expert_location.py +38 -8
  66. sglang/srt/eplb/expert_location_updater.py +1 -1
  67. sglang/srt/function_call/base_format_detector.py +3 -6
  68. sglang/srt/function_call/ebnf_composer.py +11 -9
  69. sglang/srt/function_call/function_call_parser.py +9 -2
  70. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  71. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  72. sglang/srt/function_call/json_array_parser.py +63 -0
  73. sglang/srt/function_call/kimik2_detector.py +17 -4
  74. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  75. sglang/srt/function_call/utils.py +96 -5
  76. sglang/srt/grpc/__init__.py +1 -0
  77. sglang/srt/grpc/compile_proto.py +245 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  81. sglang/srt/layers/activation.py +143 -9
  82. sglang/srt/layers/attention/aiter_backend.py +106 -82
  83. sglang/srt/layers/attention/ascend_backend.py +115 -9
  84. sglang/srt/layers/attention/attention_registry.py +206 -0
  85. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  86. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  87. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  88. sglang/srt/layers/attention/fla/chunk.py +242 -0
  89. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  90. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  91. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  92. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  93. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  94. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  95. sglang/srt/layers/attention/fla/index.py +37 -0
  96. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  97. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  98. sglang/srt/layers/attention/fla/op.py +66 -0
  99. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  100. sglang/srt/layers/attention/fla/utils.py +331 -0
  101. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  102. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  103. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  104. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  105. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  106. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  107. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  108. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  109. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  111. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  112. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  113. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  114. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  115. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  121. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  122. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  123. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  124. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  125. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  126. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  127. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  128. sglang/srt/layers/attention/nsa/utils.py +24 -0
  129. sglang/srt/layers/attention/nsa_backend.py +887 -0
  130. sglang/srt/layers/attention/tbo_backend.py +6 -6
  131. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  132. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  133. sglang/srt/layers/attention/triton_backend.py +57 -7
  134. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  135. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  136. sglang/srt/layers/attention/vision.py +58 -0
  137. sglang/srt/layers/attention/wave_backend.py +4 -4
  138. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  139. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  140. sglang/srt/layers/communicator.py +53 -7
  141. sglang/srt/layers/dp_attention.py +41 -2
  142. sglang/srt/layers/elementwise.py +3 -1
  143. sglang/srt/layers/layernorm.py +34 -15
  144. sglang/srt/layers/linear.py +55 -7
  145. sglang/srt/layers/logits_processor.py +44 -12
  146. sglang/srt/layers/moe/__init__.py +2 -1
  147. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  148. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  149. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  150. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  151. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  167. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  169. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  170. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  171. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  172. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  173. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  174. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  175. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  176. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  177. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  178. sglang/srt/layers/moe/topk.py +30 -9
  179. sglang/srt/layers/moe/utils.py +22 -7
  180. sglang/srt/layers/parameter.py +23 -6
  181. sglang/srt/layers/quantization/awq.py +19 -7
  182. sglang/srt/layers/quantization/base_config.py +11 -6
  183. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  184. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  185. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  186. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  187. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  188. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  189. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  190. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  191. sglang/srt/layers/quantization/fp8.py +78 -49
  192. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  193. sglang/srt/layers/quantization/gptq.py +25 -17
  194. sglang/srt/layers/quantization/modelopt_quant.py +225 -57
  195. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  196. sglang/srt/layers/quantization/mxfp4.py +77 -42
  197. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  199. sglang/srt/layers/quantization/quark/utils.py +97 -0
  200. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  201. sglang/srt/layers/quantization/unquant.py +135 -47
  202. sglang/srt/layers/quantization/w4afp8.py +26 -17
  203. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  204. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  205. sglang/srt/layers/rocm_linear_utils.py +44 -0
  206. sglang/srt/layers/rotary_embedding.py +78 -49
  207. sglang/srt/layers/sampler.py +213 -21
  208. sglang/srt/layers/utils.py +23 -0
  209. sglang/srt/lora/backend/base_backend.py +50 -8
  210. sglang/srt/lora/backend/chunked_backend.py +348 -0
  211. sglang/srt/lora/backend/triton_backend.py +99 -5
  212. sglang/srt/lora/layers.py +32 -0
  213. sglang/srt/lora/lora.py +8 -3
  214. sglang/srt/lora/lora_manager.py +52 -118
  215. sglang/srt/lora/mem_pool.py +25 -11
  216. sglang/srt/lora/triton_ops/__init__.py +4 -0
  217. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  219. sglang/srt/lora/utils.py +22 -11
  220. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  221. sglang/srt/managers/cache_controller.py +215 -314
  222. sglang/srt/managers/data_parallel_controller.py +115 -80
  223. sglang/srt/managers/detokenizer_manager.py +19 -15
  224. sglang/srt/managers/disagg_service.py +46 -0
  225. sglang/srt/managers/io_struct.py +340 -109
  226. sglang/srt/managers/mm_utils.py +44 -6
  227. sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
  228. sglang/srt/managers/multimodal_processor.py +1 -2
  229. sglang/srt/managers/overlap_utils.py +53 -0
  230. sglang/srt/managers/schedule_batch.py +240 -138
  231. sglang/srt/managers/schedule_policy.py +147 -19
  232. sglang/srt/managers/scheduler.py +501 -304
  233. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  234. sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
  235. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  236. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  237. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  238. sglang/srt/managers/template_manager.py +3 -3
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +321 -632
  241. sglang/srt/managers/tp_worker.py +81 -22
  242. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  243. sglang/srt/managers/utils.py +1 -45
  244. sglang/srt/mem_cache/allocator.py +15 -21
  245. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  246. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  247. sglang/srt/mem_cache/chunk_cache.py +8 -1
  248. sglang/srt/mem_cache/evict_policy.py +23 -0
  249. sglang/srt/mem_cache/hicache_storage.py +58 -34
  250. sglang/srt/mem_cache/hiradix_cache.py +227 -80
  251. sglang/srt/mem_cache/memory_pool.py +535 -58
  252. sglang/srt/mem_cache/memory_pool_host.py +239 -223
  253. sglang/srt/mem_cache/radix_cache.py +222 -73
  254. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  255. sglang/srt/mem_cache/storage/__init__.py +10 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  257. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  258. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  259. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  260. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  261. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  262. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  263. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
  264. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  265. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  266. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
  267. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  268. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  269. sglang/srt/metrics/collector.py +519 -132
  270. sglang/srt/metrics/func_timer.py +2 -7
  271. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  272. sglang/srt/metrics/utils.py +55 -0
  273. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  274. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  275. sglang/srt/model_executor/forward_batch_info.py +98 -57
  276. sglang/srt/model_executor/model_runner.py +433 -158
  277. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  278. sglang/srt/model_loader/__init__.py +9 -3
  279. sglang/srt/model_loader/loader.py +133 -5
  280. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  281. sglang/srt/model_loader/weight_utils.py +158 -3
  282. sglang/srt/models/apertus.py +686 -0
  283. sglang/srt/models/bailing_moe.py +820 -217
  284. sglang/srt/models/bailing_moe_nextn.py +168 -0
  285. sglang/srt/models/deepseek_nextn.py +6 -1
  286. sglang/srt/models/deepseek_v2.py +833 -152
  287. sglang/srt/models/dots_ocr.py +173 -0
  288. sglang/srt/models/dots_vlm.py +174 -0
  289. sglang/srt/models/dots_vlm_vit.py +337 -0
  290. sglang/srt/models/ernie4.py +1 -1
  291. sglang/srt/models/falcon_h1.py +576 -0
  292. sglang/srt/models/gemma3_causal.py +0 -2
  293. sglang/srt/models/gemma3_mm.py +1 -1
  294. sglang/srt/models/gemma3n_mm.py +2 -2
  295. sglang/srt/models/glm4_moe.py +14 -5
  296. sglang/srt/models/glm4_moe_nextn.py +2 -2
  297. sglang/srt/models/glm4v.py +5 -3
  298. sglang/srt/models/glm4v_moe.py +4 -1
  299. sglang/srt/models/gpt_oss.py +8 -31
  300. sglang/srt/models/internvl.py +28 -0
  301. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  302. sglang/srt/models/llama.py +4 -0
  303. sglang/srt/models/llama4.py +9 -0
  304. sglang/srt/models/llama_eagle3.py +13 -0
  305. sglang/srt/models/longcat_flash.py +3 -3
  306. sglang/srt/models/longcat_flash_nextn.py +1 -1
  307. sglang/srt/models/minicpmv.py +165 -3
  308. sglang/srt/models/mllama4.py +40 -4
  309. sglang/srt/models/opt.py +637 -0
  310. sglang/srt/models/qwen2_5_vl.py +29 -5
  311. sglang/srt/models/qwen2_audio.py +1 -1
  312. sglang/srt/models/qwen2_moe.py +124 -14
  313. sglang/srt/models/qwen2_vl.py +1 -1
  314. sglang/srt/models/qwen3.py +26 -5
  315. sglang/srt/models/qwen3_moe.py +71 -12
  316. sglang/srt/models/qwen3_next.py +1069 -0
  317. sglang/srt/models/qwen3_next_mtp.py +112 -0
  318. sglang/srt/models/qwen3_vl.py +787 -0
  319. sglang/srt/models/qwen3_vl_moe.py +471 -0
  320. sglang/srt/models/registry.py +15 -3
  321. sglang/srt/models/sarashina2_vision.py +269 -0
  322. sglang/srt/models/solar.py +505 -0
  323. sglang/srt/models/starcoder2.py +357 -0
  324. sglang/srt/models/step3_vl.py +1 -1
  325. sglang/srt/models/torch_native_llama.py +10 -3
  326. sglang/srt/models/utils.py +51 -0
  327. sglang/srt/multimodal/processors/base_processor.py +15 -7
  328. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  329. sglang/srt/multimodal/processors/glm4v.py +9 -9
  330. sglang/srt/multimodal/processors/internvl.py +153 -129
  331. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  332. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  333. sglang/srt/offloader.py +27 -3
  334. sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
  335. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  336. sglang/srt/sampling/sampling_batch_info.py +38 -17
  337. sglang/srt/sampling/sampling_params.py +7 -0
  338. sglang/srt/server_args.py +1030 -254
  339. sglang/srt/server_args_config_parser.py +146 -0
  340. sglang/srt/single_batch_overlap.py +151 -0
  341. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  342. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  343. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  344. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  345. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  346. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  347. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  348. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  349. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  350. sglang/srt/speculative/eagle_worker.py +253 -136
  351. sglang/srt/speculative/ngram_utils.py +428 -0
  352. sglang/srt/speculative/ngram_worker.py +245 -0
  353. sglang/srt/speculative/spec_info.py +52 -0
  354. sglang/srt/speculative/spec_utils.py +606 -0
  355. sglang/srt/speculative/standalone_worker.py +109 -0
  356. sglang/srt/torch_memory_saver_adapter.py +5 -7
  357. sglang/srt/tracing/trace.py +578 -0
  358. sglang/srt/two_batch_overlap.py +8 -5
  359. sglang/srt/utils/__init__.py +2 -0
  360. sglang/srt/{utils.py → utils/common.py} +445 -77
  361. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  362. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  363. sglang/srt/utils/rpd_utils.py +452 -0
  364. sglang/srt/utils/slow_rank_detector.py +71 -0
  365. sglang/srt/warmup.py +8 -4
  366. sglang/srt/weight_sync/utils.py +2 -2
  367. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  368. sglang/test/few_shot_gsm8k.py +1 -0
  369. sglang/test/get_logits_ut.py +57 -0
  370. sglang/test/run_eval.py +79 -11
  371. sglang/test/runners.py +5 -1
  372. sglang/test/simple_eval_common.py +5 -2
  373. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  374. sglang/test/test_block_fp8.py +2 -2
  375. sglang/test/test_cutlass_moe.py +24 -6
  376. sglang/test/test_deterministic.py +297 -0
  377. sglang/test/test_disaggregation_utils.py +77 -0
  378. sglang/test/test_fp4_moe.py +370 -1
  379. sglang/test/test_programs.py +1 -1
  380. sglang/test/test_utils.py +383 -5
  381. sglang/utils.py +22 -1
  382. sglang/version.py +1 -1
  383. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  384. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
  385. sglang/srt/disaggregation/launch_lb.py +0 -118
  386. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  387. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  388. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  389. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  390. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  391. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  392. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  393. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  394. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  395. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
@@ -42,10 +42,25 @@ from sglang.srt.layers.moe import (
42
42
  )
43
43
  from sglang.srt.managers.schedule_batch import global_server_args_dict
44
44
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
45
- from sglang.srt.utils import is_cuda, is_flashinfer_available, is_sm100_supported
45
+ from sglang.srt.utils import (
46
+ get_bool_env_var,
47
+ is_cuda,
48
+ is_flashinfer_available,
49
+ is_gfx95_supported,
50
+ is_hip,
51
+ is_sm90_supported,
52
+ is_sm100_supported,
53
+ prepare_weight_cache,
54
+ )
46
55
 
47
56
  _is_flashinfer_available = is_flashinfer_available()
57
+ _is_sm90_supported = is_cuda() and is_sm90_supported()
48
58
  _is_sm100_supported = is_cuda() and is_sm100_supported()
59
+ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
60
+ _is_gfx95_supported = is_gfx95_supported()
61
+
62
+ if _use_aiter and _is_gfx95_supported:
63
+ from sglang.srt.layers.quantization.rocm_mxfp4_utils import fused_rms_mxfp4_quant
49
64
 
50
65
  FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
51
66
 
@@ -201,6 +216,7 @@ class LayerCommunicator:
201
216
  hidden_states: torch.Tensor,
202
217
  residual: torch.Tensor,
203
218
  forward_batch: ForwardBatch,
219
+ qaunt_format: str = "",
204
220
  ):
205
221
  if hidden_states.shape[0] == 0:
206
222
  residual = hidden_states
@@ -218,11 +234,34 @@ class LayerCommunicator:
218
234
  else:
219
235
  if residual is None:
220
236
  residual = hidden_states
221
- hidden_states = self.input_layernorm(hidden_states)
237
+
238
+ if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
239
+ hidden_states = fused_rms_mxfp4_quant(
240
+ hidden_states,
241
+ self.input_layernorm.weight,
242
+ self.input_layernorm.variance_epsilon,
243
+ None,
244
+ None,
245
+ None,
246
+ None,
247
+ )
248
+ else:
249
+ hidden_states = self.input_layernorm(hidden_states)
222
250
  else:
223
- hidden_states, residual = self.input_layernorm(
224
- hidden_states, residual
225
- )
251
+ if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
252
+ hidden_states, residual = fused_rms_mxfp4_quant(
253
+ hidden_states,
254
+ self.input_layernorm.weight,
255
+ self.input_layernorm.variance_epsilon,
256
+ None,
257
+ None,
258
+ None,
259
+ residual,
260
+ )
261
+ else:
262
+ hidden_states, residual = self.input_layernorm(
263
+ hidden_states, residual
264
+ )
226
265
 
227
266
  hidden_states = self._communicate_simple_fn(
228
267
  hidden_states=hidden_states,
@@ -237,7 +276,11 @@ class LayerCommunicator:
237
276
  hidden_states: torch.Tensor,
238
277
  residual: torch.Tensor,
239
278
  forward_batch: ForwardBatch,
279
+ cache=None,
240
280
  ):
281
+ if cache is not None:
282
+ self._context.cache = cache
283
+
241
284
  return self._communicate_with_all_reduce_and_layer_norm_fn(
242
285
  hidden_states=hidden_states,
243
286
  residual=residual,
@@ -311,6 +354,7 @@ class CommunicateContext:
311
354
  attn_tp_size: int
312
355
  attn_dp_size: int
313
356
  tp_size: int
357
+ cache = None
314
358
 
315
359
  def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
316
360
  return self.process_group_sizes[a] == self.process_group_sizes[b]
@@ -484,17 +528,19 @@ class CommunicateWithAllReduceAndLayerNormFn:
484
528
  # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
485
529
  # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
486
530
  if (
487
- _is_sm100_supported
531
+ (_is_sm100_supported or _is_sm90_supported)
488
532
  and _is_flashinfer_available
489
533
  and hasattr(layernorm, "forward_with_allreduce_fusion")
490
534
  and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
491
- and hidden_states.shape[0] <= 2048
535
+ and hidden_states.shape[0] <= 4096
492
536
  ):
493
537
  hidden_states, residual = layernorm.forward_with_allreduce_fusion(
494
538
  hidden_states, residual
495
539
  )
496
540
  else:
497
541
  hidden_states = tensor_model_parallel_all_reduce(hidden_states)
542
+ if context.cache is not None:
543
+ _ = prepare_weight_cache(hidden_states, context.cache)
498
544
  hidden_states, residual = layernorm(hidden_states, residual)
499
545
  return hidden_states, residual
500
546
 
@@ -17,6 +17,7 @@ from sglang.srt.distributed import (
17
17
  get_tp_group,
18
18
  tensor_model_parallel_all_reduce,
19
19
  )
20
+ from sglang.srt.utils import get_bool_env_var, is_hip
20
21
 
21
22
  if TYPE_CHECKING:
22
23
  from sglang.srt.configs.model_config import ModelConfig
@@ -36,6 +37,9 @@ _LOCAL_ATTN_DP_SIZE: Optional[int] = None
36
37
  _LOCAL_ATTN_DP_RANK: Optional[int] = None
37
38
  _ENABLE_DP_ATTENTION_FLAG: bool = False
38
39
 
40
+ _is_hip = is_hip()
41
+ _USE_ROCM700A_WA = _is_hip and get_bool_env_var("SGLANG_USE_ROCM700A")
42
+
39
43
 
40
44
  class DpPaddingMode(IntEnum):
41
45
 
@@ -51,7 +55,12 @@ class DpPaddingMode(IntEnum):
51
55
  return self == DpPaddingMode.SUM_LEN
52
56
 
53
57
  @classmethod
54
- def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DpPaddingMode:
58
+ def get_dp_padding_mode(
59
+ cls, is_extend_in_batch, global_num_tokens: List[int]
60
+ ) -> DpPaddingMode:
61
+ if is_extend_in_batch:
62
+ return DpPaddingMode.SUM_LEN
63
+
55
64
  # we choose the mode that minimizes the communication cost
56
65
  max_len = max(global_num_tokens)
57
66
  sum_len = sum(global_num_tokens)
@@ -62,7 +71,12 @@ class DpPaddingMode(IntEnum):
62
71
 
63
72
  @classmethod
64
73
  def get_default_mode_in_cuda_graph(cls) -> DpPaddingMode:
65
- return cls.MAX_LEN
74
+ # TODO(kkhuang-amd): noqa, temporary work-around for rocm 7.0.0 alpha
75
+ # it can be safely removed later, once RCCL fixed
76
+ if _USE_ROCM700A_WA:
77
+ return cls.SUM_LEN
78
+ else:
79
+ return cls.MAX_LEN
66
80
 
67
81
 
68
82
  class _DpGatheredBufferWrapper:
@@ -119,6 +133,18 @@ class _DpGatheredBufferWrapper:
119
133
  def get_dp_global_num_tokens(cls) -> List[int]:
120
134
  return cls._global_num_tokens
121
135
 
136
+ @classmethod
137
+ def get_dp_hidden_size(cls) -> int:
138
+ return cls._hidden_size
139
+
140
+ @classmethod
141
+ def get_dp_dtype(cls) -> torch.dtype:
142
+ return cls._dtype
143
+
144
+ @classmethod
145
+ def get_dp_device(cls) -> torch.device:
146
+ return cls._device
147
+
122
148
 
123
149
  def set_dp_buffer_len(
124
150
  global_dp_buffer_len: int,
@@ -150,6 +176,18 @@ def get_dp_global_num_tokens() -> List[int]:
150
176
  return _DpGatheredBufferWrapper.get_dp_global_num_tokens()
151
177
 
152
178
 
179
+ def get_dp_hidden_size() -> int:
180
+ return _DpGatheredBufferWrapper.get_dp_hidden_size()
181
+
182
+
183
+ def get_dp_dtype() -> torch.dtype:
184
+ return _DpGatheredBufferWrapper.get_dp_dtype()
185
+
186
+
187
+ def get_dp_device() -> torch.device:
188
+ return _DpGatheredBufferWrapper.get_dp_device()
189
+
190
+
153
191
  def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
154
192
  if not enable_dp_attention:
155
193
  return tp_rank, tp_size, 0
@@ -225,6 +263,7 @@ def initialize_dp_attention(
225
263
  use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
226
264
  use_pymscclpp=False,
227
265
  use_custom_allreduce=False,
266
+ use_torch_symm_mem=False,
228
267
  use_hpu_communicator=False,
229
268
  use_xpu_communicator=False,
230
269
  use_npu_communicator=False,
@@ -187,7 +187,9 @@ fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(
187
187
 
188
188
  def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
189
189
  assert len(x.shape) == 2
190
- assert x.shape == residual.shape and x.dtype == residual.dtype
190
+ assert (
191
+ x.shape == residual.shape and x.dtype == residual.dtype
192
+ ), f"{x.shape=} {residual.shape=} {x.dtype=} {residual.dtype=}"
191
193
  output, mid = torch.empty_like(x), torch.empty_like(x)
192
194
  bs, hidden_dim = x.shape
193
195
  if autotune:
@@ -18,6 +18,7 @@ from typing import Optional, Tuple, Union
18
18
 
19
19
  import torch
20
20
  import torch.nn as nn
21
+ from packaging.version import Version
21
22
 
22
23
  from sglang.srt.custom_op import CustomOp
23
24
  from sglang.srt.utils import (
@@ -25,32 +26,38 @@ from sglang.srt.utils import (
25
26
  get_bool_env_var,
26
27
  is_cpu,
27
28
  is_cuda,
29
+ is_flashinfer_available,
28
30
  is_hip,
29
31
  is_npu,
32
+ is_xpu,
30
33
  supports_custom_op,
31
34
  )
32
35
 
33
36
  _is_cuda = is_cuda()
37
+ _is_flashinfer_available = is_flashinfer_available()
34
38
  _is_hip = is_hip()
35
39
  _is_npu = is_npu()
36
40
  _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
37
41
  _is_cpu_amx_available = cpu_has_amx_support()
38
42
  _is_cpu = is_cpu()
43
+ _is_xpu = is_xpu()
39
44
 
40
45
  if _is_cuda:
41
- from sgl_kernel import (
42
- fused_add_rmsnorm,
43
- gemma_fused_add_rmsnorm,
44
- gemma_rmsnorm,
45
- rmsnorm,
46
- )
46
+ if _is_flashinfer_available:
47
+ from flashinfer.norm import fused_add_rmsnorm
48
+ else:
49
+ from sgl_kernel import fused_add_rmsnorm
50
+ from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
47
51
 
48
52
  if _use_aiter:
49
53
  from aiter import rmsnorm2d_fwd as rms_norm
50
54
  from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
51
55
  elif _is_hip:
56
+ import vllm
52
57
  from vllm._custom_ops import fused_add_rms_norm, rms_norm
53
58
 
59
+ _vllm_version = Version(vllm.__version__)
60
+
54
61
  logger = logging.getLogger(__name__)
55
62
 
56
63
  if _is_npu:
@@ -73,6 +80,8 @@ class RMSNorm(CustomOp):
73
80
  )
74
81
  if _use_aiter:
75
82
  self._forward_method = self.forward_aiter
83
+ if get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_INFERENCE"):
84
+ self._forward_method = self.forward_native
76
85
 
77
86
  def forward_cuda(
78
87
  self,
@@ -127,8 +136,21 @@ class RMSNorm(CustomOp):
127
136
  # NOTE: Remove this if aiter kernel supports discontinuous input
128
137
  x = x.contiguous()
129
138
  if residual is not None:
130
- fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
131
- return x, residual
139
+ if _vllm_version < Version("0.9"):
140
+ fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
141
+ return x, residual
142
+ else:
143
+ residual_out = torch.empty_like(x)
144
+ output = torch.empty_like(x)
145
+ fused_add_rms_norm(
146
+ output,
147
+ x,
148
+ residual_out,
149
+ residual,
150
+ self.weight.data,
151
+ self.variance_epsilon,
152
+ )
153
+ return output, residual_out
132
154
  out = torch.empty_like(x)
133
155
  rms_norm(out, x, self.weight.data, self.variance_epsilon)
134
156
  return out
@@ -271,16 +293,11 @@ class GemmaRMSNorm(CustomOp):
271
293
  x: torch.Tensor,
272
294
  residual: Optional[torch.Tensor] = None,
273
295
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
274
- orig_dtype = x.dtype
275
296
  if residual is not None:
276
297
  x = x + residual
277
298
  residual = x
278
299
 
279
- x = x.float()
280
- variance = torch_npu.mean(torch_npu.pow(x, 2), dim=-1, keepdim=True)
281
- x = x * torch_npu.rsqrt(variance + self.variance_epsilon)
282
- x = x * (1.0 + self.weight.float())
283
- x = x.to(orig_dtype)
300
+ x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
284
301
  return x if residual is None else (x, residual)
285
302
 
286
303
 
@@ -312,7 +329,9 @@ class Gemma3RMSNorm(CustomOp):
312
329
  return f"{tuple(self.weight.shape)}, eps={self.eps}"
313
330
 
314
331
 
315
- if not (_is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available)):
332
+ if not (
333
+ _is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_xpu
334
+ ):
316
335
  logger.info(
317
336
  "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
318
337
  )
@@ -31,6 +31,7 @@ from sglang.srt.layers.parameter import (
31
31
  _ColumnvLLMParameter,
32
32
  )
33
33
  from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
34
+ from sglang.srt.layers.utils import pad_or_narrow_weight
34
35
  from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
35
36
 
36
37
  if TYPE_CHECKING:
@@ -235,9 +236,8 @@ class ReplicatedLinear(LinearBase):
235
236
  loaded_weight = loaded_weight[:1]
236
237
  else:
237
238
  raise ValueError(f"{loaded_weight} are not all equal")
238
- assert (
239
- param.size() == loaded_weight.size()
240
- ), f"Loading weight error: param: {param.size()}, loaded_weight: {loaded_weight.size()}"
239
+
240
+ assert param.size() == loaded_weight.size()
241
241
  param.data.copy_(loaded_weight)
242
242
 
243
243
  def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -626,9 +626,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
626
626
  # bitsandbytes loads the weights of the specific portion
627
627
  # no need to narrow here
628
628
  if not use_bitsandbytes_4bit and not self.use_presharded_weights:
629
- loaded_weight = loaded_weight.narrow(
630
- output_dim, start_idx, shard_size
631
- )
629
+ # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
630
+ end_idx = start_idx + shard_size
631
+ if end_idx > loaded_weight.shape[output_dim]:
632
+ loaded_weight = pad_or_narrow_weight(
633
+ loaded_weight, output_dim, start_idx, shard_size
634
+ )
635
+ else:
636
+ loaded_weight = loaded_weight.narrow(
637
+ output_dim, start_idx, shard_size
638
+ )
632
639
 
633
640
  # Special case for AQLM codebooks.
634
641
  elif is_metadata:
@@ -894,6 +901,35 @@ class QKVParallelLinear(ColumnParallelLinear):
894
901
  )
895
902
  self.weight_loader_v2(param, loaded_weight_shard, shard_id)
896
903
 
904
+ def _load_qkv_block_scale(
905
+ self, param: BasevLLMParameter, loaded_weight: torch.Tensor
906
+ ):
907
+ block_n, _ = self.quant_method.quant_config.weight_block_size
908
+ q_size = self.total_num_heads * self.head_size // block_n
909
+ k_size = self.total_num_kv_heads * self.head_size // block_n
910
+ v_size = self.total_num_kv_heads * self.head_size // block_n
911
+ shard_offsets = [
912
+ # (shard_id, shard_offset, shard_size)
913
+ ("q", 0, q_size),
914
+ ("k", q_size, k_size),
915
+ ("v", q_size + k_size, v_size),
916
+ ]
917
+ for shard_id, shard_offset, shard_size in shard_offsets:
918
+ loaded_weight_shard = loaded_weight.narrow(
919
+ param.output_dim, shard_offset, shard_size
920
+ )
921
+ rank_shard_offset = self._get_shard_offset_mapping(shard_id) // block_n
922
+ rank_shard_size = self._get_shard_size_mapping(shard_id) // block_n
923
+ param.load_qkv_weight(
924
+ loaded_weight=loaded_weight_shard,
925
+ num_heads=self.num_kv_head_replicas,
926
+ shard_id=shard_id,
927
+ shard_offset=rank_shard_offset,
928
+ shard_size=rank_shard_size,
929
+ tp_rank=self.tp_rank,
930
+ use_presharded_weights=self.use_presharded_weights,
931
+ )
932
+
897
933
  def weight_loader_v2(
898
934
  self,
899
935
  param: BasevLLMParameter,
@@ -907,6 +943,9 @@ class QKVParallelLinear(ColumnParallelLinear):
907
943
  elif type(param) in (RowvLLMParameter, BasevLLMParameter):
908
944
  param.load_qkv_weight(loaded_weight=loaded_weight)
909
945
  return
946
+ elif isinstance(param, BlockQuantScaleParameter):
947
+ self._load_qkv_block_scale(param, loaded_weight)
948
+ return
910
949
  # TODO: @dsikka - move to parameter.py
911
950
  self._load_fused_module_from_checkpoint(param, loaded_weight)
912
951
  return
@@ -1271,7 +1310,16 @@ class RowParallelLinear(LinearBase):
1271
1310
  shard_size,
1272
1311
  )
1273
1312
  else:
1274
- loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
1313
+ # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
1314
+ end_idx = start_idx + shard_size
1315
+ if end_idx > loaded_weight.shape[input_dim]:
1316
+ loaded_weight = pad_or_narrow_weight(
1317
+ loaded_weight, input_dim, start_idx, shard_size
1318
+ )
1319
+ else:
1320
+ loaded_weight = loaded_weight.narrow(
1321
+ input_dim, start_idx, shard_size
1322
+ )
1275
1323
 
1276
1324
  # Special case for loading scales off disk, which often do not
1277
1325
  # have a shape (such as in the case of AutoFP8).
@@ -35,6 +35,9 @@ from sglang.srt.layers.dp_attention import (
35
35
  get_attention_dp_rank,
36
36
  get_attention_dp_size,
37
37
  get_attention_tp_size,
38
+ get_dp_device,
39
+ get_dp_dtype,
40
+ get_dp_hidden_size,
38
41
  get_global_dp_buffer,
39
42
  get_local_attention_dp_size,
40
43
  set_dp_buffer_len,
@@ -46,10 +49,12 @@ from sglang.srt.model_executor.forward_batch_info import (
46
49
  ForwardBatch,
47
50
  ForwardMode,
48
51
  )
49
- from sglang.srt.utils import dump_to_file, use_intel_amx_backend
52
+ from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend
50
53
 
51
54
  logger = logging.getLogger(__name__)
52
55
 
56
+ _is_npu = is_npu()
57
+
53
58
 
54
59
  @dataclasses.dataclass
55
60
  class LogitsProcessorOutput:
@@ -67,7 +72,10 @@ class LogitsProcessorOutput:
67
72
  next_token_top_logprobs_val: Optional[List] = None
68
73
  next_token_top_logprobs_idx: Optional[List] = None
69
74
  # The logprobs and ids of the requested token ids in output positions. shape: [#seq, n] (n is the number of requested token ids)
70
- next_token_token_ids_logprobs_val: Optional[List] = None
75
+ # Can contain either lists or GPU tensors (for delayed copy optimization in prefill-only requests)
76
+ next_token_token_ids_logprobs_val: Optional[
77
+ List[Union[List[float], torch.Tensor]]
78
+ ] = None
71
79
  next_token_token_ids_logprobs_idx: Optional[List] = None
72
80
 
73
81
  ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
@@ -180,10 +188,13 @@ class LogitsMetadata:
180
188
  )
181
189
  else:
182
190
  dp_local_start_pos = cumtokens[dp_rank - 1]
183
- dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
184
191
 
185
192
  self.dp_local_start_pos = dp_local_start_pos
186
- self.dp_local_num_tokens = dp_local_num_tokens
193
+ self.dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
194
+
195
+ hidden_size = get_dp_hidden_size()
196
+ dtype = get_dp_dtype()
197
+ device = get_dp_device()
187
198
 
188
199
  if self.global_num_tokens_for_logprob_cpu is not None:
189
200
  # create a smaller buffer to reduce peak memory usage
@@ -191,10 +202,13 @@ class LogitsMetadata:
191
202
  else:
192
203
  self.global_dp_buffer_len = self.global_dp_buffer_len
193
204
 
194
- set_dp_buffer_len(
195
- self.global_dp_buffer_len,
196
- self.dp_local_num_tokens,
197
- self.global_num_tokens_for_logprob_cpu,
205
+ self.gathered_buffer = torch.empty(
206
+ (
207
+ self.global_dp_buffer_len,
208
+ hidden_size,
209
+ ),
210
+ dtype=dtype,
211
+ device=device,
198
212
  )
199
213
 
200
214
 
@@ -206,6 +220,7 @@ class LogitsProcessor(nn.Module):
206
220
  self.config = config
207
221
  self.logit_scale = logit_scale
208
222
  self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
223
+ self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"]
209
224
  if self.use_attn_tp_group:
210
225
  self.attn_tp_size = get_attention_tp_size()
211
226
  self.do_tensor_parallel_all_gather = (
@@ -441,13 +456,17 @@ class LogitsProcessor(nn.Module):
441
456
  if self.do_tensor_parallel_all_gather_dp_attn:
442
457
  logits_metadata.compute_dp_attention_metadata()
443
458
  hidden_states, local_hidden_states = (
444
- get_global_dp_buffer(),
459
+ logits_metadata.gathered_buffer,
445
460
  hidden_states,
446
461
  )
447
462
  dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
448
463
 
449
464
  if hasattr(lm_head, "weight"):
450
- if use_intel_amx_backend(lm_head):
465
+ if self.use_fp32_lm_head:
466
+ logits = torch.matmul(
467
+ hidden_states.to(torch.float32), lm_head.weight.to(torch.float32).T
468
+ )
469
+ elif use_intel_amx_backend(lm_head):
451
470
  logits = torch.ops.sgl_kernel.weight_packed_linear(
452
471
  hidden_states.to(lm_head.weight.dtype),
453
472
  lm_head.weight,
@@ -461,7 +480,15 @@ class LogitsProcessor(nn.Module):
461
480
  else:
462
481
  # GGUF models
463
482
  # TODO: use weight_packed_linear for GGUF models
464
- logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)
483
+ if self.use_fp32_lm_head:
484
+ with torch.cuda.amp.autocast(enabled=False):
485
+ logits = lm_head.quant_method.apply(
486
+ lm_head, hidden_states.to(torch.float32), embedding_bias
487
+ )
488
+ else:
489
+ logits = lm_head.quant_method.apply(
490
+ lm_head, hidden_states, embedding_bias
491
+ )
465
492
 
466
493
  if self.logit_scale is not None:
467
494
  logits.mul_(self.logit_scale)
@@ -517,7 +544,12 @@ class LogitsProcessor(nn.Module):
517
544
  logits = logits[:, : self.config.vocab_size].float()
518
545
 
519
546
  if self.final_logit_softcapping:
520
- fused_softcap(logits, self.final_logit_softcapping)
547
+ if not _is_npu:
548
+ fused_softcap(logits, self.final_logit_softcapping)
549
+ else:
550
+ logits = self.final_logit_softcapping * torch.tanh(
551
+ logits / self.final_logit_softcapping
552
+ )
521
553
 
522
554
  return logits
523
555
 
@@ -1,4 +1,4 @@
1
- from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
1
+ from sglang.srt.layers.moe.moe_runner import MoeRunner, MoeRunnerConfig
2
2
  from sglang.srt.layers.moe.utils import (
3
3
  DeepEPMode,
4
4
  MoeA2ABackend,
@@ -17,6 +17,7 @@ from sglang.srt.layers.moe.utils import (
17
17
  __all__ = [
18
18
  "DeepEPMode",
19
19
  "MoeA2ABackend",
20
+ "MoeRunner",
20
21
  "MoeRunnerConfig",
21
22
  "MoeRunnerBackend",
22
23
  "initialize_moe_config",
@@ -147,8 +147,8 @@ def cutlass_w4a8_moe(
147
147
  k,
148
148
  )
149
149
 
150
- c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half)
151
- c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half)
150
+ c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)
151
+ c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16)
152
152
 
153
153
  cutlass_w4a8_moe_mm(
154
154
  c1,
@@ -166,7 +166,7 @@ def cutlass_w4a8_moe(
166
166
  topk,
167
167
  )
168
168
 
169
- intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half)
169
+ intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16)
170
170
  silu_and_mul(c1, intermediate)
171
171
 
172
172
  intermediate_q = torch.empty(
@@ -1104,10 +1104,10 @@ def ep_gather(
1104
1104
  input_index: torch.Tensor,
1105
1105
  output_tensor: torch.Tensor,
1106
1106
  ):
1107
- BLOCK_D = 1024 if not is_in_ci() else 128 # block size of quantization
1108
1107
  num_warps = 2
1109
1108
  num_tokens = output_tensor.shape[0]
1110
1109
  hidden_size = input_tensor.shape[1]
1110
+ BLOCK_D = 128 if hidden_size % 1024 != 0 else 1024 # block size of quantization
1111
1111
  assert hidden_size % BLOCK_D == 0
1112
1112
  grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
1113
1113
  _fwd_kernel_ep_gather[grid](
@@ -1416,7 +1416,7 @@ def zero_experts_compute_triton(
1416
1416
  zero_expert_scales[zero_expert_mask] = 0.0
1417
1417
 
1418
1418
  normal_expert_mask = expert_indices >= num_experts
1419
- expert_indices[normal_expert_mask] = 0
1419
+ expert_indices[normal_expert_mask] = -1
1420
1420
  expert_scales[normal_expert_mask] = 0.0
1421
1421
 
1422
1422
  output = torch.zeros_like(hidden_states).to(hidden_states.device)