sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in that registry.
Files changed (395)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/lang/interpreter.py +1 -1
  7. sglang/launch_server.py +14 -0
  8. sglang/profiler.py +2 -2
  9. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  10. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  11. sglang/srt/configs/__init__.py +8 -0
  12. sglang/srt/configs/device_config.py +3 -1
  13. sglang/srt/configs/dots_ocr.py +64 -0
  14. sglang/srt/configs/dots_vlm.py +139 -0
  15. sglang/srt/configs/falcon_h1.py +360 -0
  16. sglang/srt/configs/internvl.py +6 -0
  17. sglang/srt/configs/load_config.py +9 -0
  18. sglang/srt/configs/model_config.py +181 -82
  19. sglang/srt/configs/qwen3_next.py +326 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +71 -19
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +326 -53
  44. sglang/srt/disaggregation/prefill.py +36 -17
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +192 -113
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  52. sglang/srt/entrypoints/grpc_server.py +810 -0
  53. sglang/srt/entrypoints/http_server.py +132 -57
  54. sglang/srt/entrypoints/openai/protocol.py +115 -7
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +207 -58
  57. sglang/srt/entrypoints/openai/serving_completions.py +17 -4
  58. sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +49 -4
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/environ.py +285 -0
  63. sglang/srt/eplb/eplb_manager.py +2 -2
  64. sglang/srt/eplb/expert_distribution.py +26 -13
  65. sglang/srt/eplb/expert_location.py +38 -8
  66. sglang/srt/eplb/expert_location_updater.py +1 -1
  67. sglang/srt/function_call/base_format_detector.py +3 -6
  68. sglang/srt/function_call/ebnf_composer.py +11 -9
  69. sglang/srt/function_call/function_call_parser.py +9 -2
  70. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  71. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  72. sglang/srt/function_call/json_array_parser.py +63 -0
  73. sglang/srt/function_call/kimik2_detector.py +17 -4
  74. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  75. sglang/srt/function_call/utils.py +96 -5
  76. sglang/srt/grpc/__init__.py +1 -0
  77. sglang/srt/grpc/compile_proto.py +245 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  81. sglang/srt/layers/activation.py +143 -9
  82. sglang/srt/layers/attention/aiter_backend.py +106 -82
  83. sglang/srt/layers/attention/ascend_backend.py +115 -9
  84. sglang/srt/layers/attention/attention_registry.py +206 -0
  85. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  86. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  87. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  88. sglang/srt/layers/attention/fla/chunk.py +242 -0
  89. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  90. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  91. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  92. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  93. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  94. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  95. sglang/srt/layers/attention/fla/index.py +37 -0
  96. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  97. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  98. sglang/srt/layers/attention/fla/op.py +66 -0
  99. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  100. sglang/srt/layers/attention/fla/utils.py +331 -0
  101. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  102. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  103. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  104. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  105. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  106. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  107. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  108. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  109. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  111. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  112. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  113. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  114. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  115. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  121. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  122. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  123. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  124. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  125. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  126. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  127. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  128. sglang/srt/layers/attention/nsa/utils.py +24 -0
  129. sglang/srt/layers/attention/nsa_backend.py +887 -0
  130. sglang/srt/layers/attention/tbo_backend.py +6 -6
  131. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  132. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  133. sglang/srt/layers/attention/triton_backend.py +57 -7
  134. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  135. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  136. sglang/srt/layers/attention/vision.py +58 -0
  137. sglang/srt/layers/attention/wave_backend.py +4 -4
  138. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  139. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  140. sglang/srt/layers/communicator.py +53 -7
  141. sglang/srt/layers/dp_attention.py +41 -2
  142. sglang/srt/layers/elementwise.py +3 -1
  143. sglang/srt/layers/layernorm.py +34 -15
  144. sglang/srt/layers/linear.py +55 -7
  145. sglang/srt/layers/logits_processor.py +44 -12
  146. sglang/srt/layers/moe/__init__.py +2 -1
  147. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  148. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  149. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  150. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  151. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  167. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  169. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  170. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  171. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  172. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  173. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  174. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  175. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  176. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  177. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  178. sglang/srt/layers/moe/topk.py +30 -9
  179. sglang/srt/layers/moe/utils.py +22 -7
  180. sglang/srt/layers/parameter.py +23 -6
  181. sglang/srt/layers/quantization/awq.py +19 -7
  182. sglang/srt/layers/quantization/base_config.py +11 -6
  183. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  184. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  185. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  186. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  187. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  188. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  189. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  190. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  191. sglang/srt/layers/quantization/fp8.py +78 -49
  192. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  193. sglang/srt/layers/quantization/gptq.py +25 -17
  194. sglang/srt/layers/quantization/modelopt_quant.py +225 -57
  195. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  196. sglang/srt/layers/quantization/mxfp4.py +77 -42
  197. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  199. sglang/srt/layers/quantization/quark/utils.py +97 -0
  200. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  201. sglang/srt/layers/quantization/unquant.py +135 -47
  202. sglang/srt/layers/quantization/w4afp8.py +26 -17
  203. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  204. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  205. sglang/srt/layers/rocm_linear_utils.py +44 -0
  206. sglang/srt/layers/rotary_embedding.py +78 -49
  207. sglang/srt/layers/sampler.py +213 -21
  208. sglang/srt/layers/utils.py +23 -0
  209. sglang/srt/lora/backend/base_backend.py +50 -8
  210. sglang/srt/lora/backend/chunked_backend.py +348 -0
  211. sglang/srt/lora/backend/triton_backend.py +99 -5
  212. sglang/srt/lora/layers.py +32 -0
  213. sglang/srt/lora/lora.py +8 -3
  214. sglang/srt/lora/lora_manager.py +52 -118
  215. sglang/srt/lora/mem_pool.py +25 -11
  216. sglang/srt/lora/triton_ops/__init__.py +4 -0
  217. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  219. sglang/srt/lora/utils.py +22 -11
  220. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  221. sglang/srt/managers/cache_controller.py +215 -314
  222. sglang/srt/managers/data_parallel_controller.py +115 -80
  223. sglang/srt/managers/detokenizer_manager.py +19 -15
  224. sglang/srt/managers/disagg_service.py +46 -0
  225. sglang/srt/managers/io_struct.py +340 -109
  226. sglang/srt/managers/mm_utils.py +44 -6
  227. sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
  228. sglang/srt/managers/multimodal_processor.py +1 -2
  229. sglang/srt/managers/overlap_utils.py +53 -0
  230. sglang/srt/managers/schedule_batch.py +240 -138
  231. sglang/srt/managers/schedule_policy.py +147 -19
  232. sglang/srt/managers/scheduler.py +501 -304
  233. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  234. sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
  235. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  236. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  237. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  238. sglang/srt/managers/template_manager.py +3 -3
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +321 -632
  241. sglang/srt/managers/tp_worker.py +81 -22
  242. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  243. sglang/srt/managers/utils.py +1 -45
  244. sglang/srt/mem_cache/allocator.py +15 -21
  245. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  246. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  247. sglang/srt/mem_cache/chunk_cache.py +8 -1
  248. sglang/srt/mem_cache/evict_policy.py +23 -0
  249. sglang/srt/mem_cache/hicache_storage.py +58 -34
  250. sglang/srt/mem_cache/hiradix_cache.py +227 -80
  251. sglang/srt/mem_cache/memory_pool.py +535 -58
  252. sglang/srt/mem_cache/memory_pool_host.py +239 -223
  253. sglang/srt/mem_cache/radix_cache.py +222 -73
  254. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  255. sglang/srt/mem_cache/storage/__init__.py +10 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  257. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  258. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  259. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  260. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  261. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  262. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  263. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
  264. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  265. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  266. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
  267. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  268. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  269. sglang/srt/metrics/collector.py +519 -132
  270. sglang/srt/metrics/func_timer.py +2 -7
  271. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  272. sglang/srt/metrics/utils.py +55 -0
  273. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  274. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  275. sglang/srt/model_executor/forward_batch_info.py +98 -57
  276. sglang/srt/model_executor/model_runner.py +433 -158
  277. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  278. sglang/srt/model_loader/__init__.py +9 -3
  279. sglang/srt/model_loader/loader.py +133 -5
  280. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  281. sglang/srt/model_loader/weight_utils.py +158 -3
  282. sglang/srt/models/apertus.py +686 -0
  283. sglang/srt/models/bailing_moe.py +820 -217
  284. sglang/srt/models/bailing_moe_nextn.py +168 -0
  285. sglang/srt/models/deepseek_nextn.py +6 -1
  286. sglang/srt/models/deepseek_v2.py +833 -152
  287. sglang/srt/models/dots_ocr.py +173 -0
  288. sglang/srt/models/dots_vlm.py +174 -0
  289. sglang/srt/models/dots_vlm_vit.py +337 -0
  290. sglang/srt/models/ernie4.py +1 -1
  291. sglang/srt/models/falcon_h1.py +576 -0
  292. sglang/srt/models/gemma3_causal.py +0 -2
  293. sglang/srt/models/gemma3_mm.py +1 -1
  294. sglang/srt/models/gemma3n_mm.py +2 -2
  295. sglang/srt/models/glm4_moe.py +14 -5
  296. sglang/srt/models/glm4_moe_nextn.py +2 -2
  297. sglang/srt/models/glm4v.py +5 -3
  298. sglang/srt/models/glm4v_moe.py +4 -1
  299. sglang/srt/models/gpt_oss.py +8 -31
  300. sglang/srt/models/internvl.py +28 -0
  301. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  302. sglang/srt/models/llama.py +4 -0
  303. sglang/srt/models/llama4.py +9 -0
  304. sglang/srt/models/llama_eagle3.py +13 -0
  305. sglang/srt/models/longcat_flash.py +3 -3
  306. sglang/srt/models/longcat_flash_nextn.py +1 -1
  307. sglang/srt/models/minicpmv.py +165 -3
  308. sglang/srt/models/mllama4.py +40 -4
  309. sglang/srt/models/opt.py +637 -0
  310. sglang/srt/models/qwen2_5_vl.py +29 -5
  311. sglang/srt/models/qwen2_audio.py +1 -1
  312. sglang/srt/models/qwen2_moe.py +124 -14
  313. sglang/srt/models/qwen2_vl.py +1 -1
  314. sglang/srt/models/qwen3.py +26 -5
  315. sglang/srt/models/qwen3_moe.py +71 -12
  316. sglang/srt/models/qwen3_next.py +1069 -0
  317. sglang/srt/models/qwen3_next_mtp.py +112 -0
  318. sglang/srt/models/qwen3_vl.py +787 -0
  319. sglang/srt/models/qwen3_vl_moe.py +471 -0
  320. sglang/srt/models/registry.py +15 -3
  321. sglang/srt/models/sarashina2_vision.py +269 -0
  322. sglang/srt/models/solar.py +505 -0
  323. sglang/srt/models/starcoder2.py +357 -0
  324. sglang/srt/models/step3_vl.py +1 -1
  325. sglang/srt/models/torch_native_llama.py +10 -3
  326. sglang/srt/models/utils.py +51 -0
  327. sglang/srt/multimodal/processors/base_processor.py +15 -7
  328. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  329. sglang/srt/multimodal/processors/glm4v.py +9 -9
  330. sglang/srt/multimodal/processors/internvl.py +153 -129
  331. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  332. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  333. sglang/srt/offloader.py +27 -3
  334. sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
  335. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  336. sglang/srt/sampling/sampling_batch_info.py +38 -17
  337. sglang/srt/sampling/sampling_params.py +7 -0
  338. sglang/srt/server_args.py +1030 -254
  339. sglang/srt/server_args_config_parser.py +146 -0
  340. sglang/srt/single_batch_overlap.py +151 -0
  341. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  342. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  343. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  344. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  345. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  346. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  347. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  348. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  349. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  350. sglang/srt/speculative/eagle_worker.py +253 -136
  351. sglang/srt/speculative/ngram_utils.py +428 -0
  352. sglang/srt/speculative/ngram_worker.py +245 -0
  353. sglang/srt/speculative/spec_info.py +52 -0
  354. sglang/srt/speculative/spec_utils.py +606 -0
  355. sglang/srt/speculative/standalone_worker.py +109 -0
  356. sglang/srt/torch_memory_saver_adapter.py +5 -7
  357. sglang/srt/tracing/trace.py +578 -0
  358. sglang/srt/two_batch_overlap.py +8 -5
  359. sglang/srt/utils/__init__.py +2 -0
  360. sglang/srt/{utils.py → utils/common.py} +445 -77
  361. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  362. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  363. sglang/srt/utils/rpd_utils.py +452 -0
  364. sglang/srt/utils/slow_rank_detector.py +71 -0
  365. sglang/srt/warmup.py +8 -4
  366. sglang/srt/weight_sync/utils.py +2 -2
  367. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  368. sglang/test/few_shot_gsm8k.py +1 -0
  369. sglang/test/get_logits_ut.py +57 -0
  370. sglang/test/run_eval.py +79 -11
  371. sglang/test/runners.py +5 -1
  372. sglang/test/simple_eval_common.py +5 -2
  373. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  374. sglang/test/test_block_fp8.py +2 -2
  375. sglang/test/test_cutlass_moe.py +24 -6
  376. sglang/test/test_deterministic.py +297 -0
  377. sglang/test/test_disaggregation_utils.py +77 -0
  378. sglang/test/test_fp4_moe.py +370 -1
  379. sglang/test/test_programs.py +1 -1
  380. sglang/test/test_utils.py +383 -5
  381. sglang/utils.py +22 -1
  382. sglang/version.py +1 -1
  383. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  384. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
  385. sglang/srt/disaggregation/launch_lb.py +0 -118
  386. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  387. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  388. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  389. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  390. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  391. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  392. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  393. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  394. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  395. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
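The hunks below are from sglang/srt/models/deepseek_v2.py (entry 286 above, +833 -152).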
@@ -15,6 +15,7 @@
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
 """Inference-only DeepseekV2 model."""
+from __future__ import annotations

 import concurrent.futures
 import logging
@@ -25,9 +26,16 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
 from torch import nn
-from tqdm import tqdm
 from transformers import PretrainedConfig

+from sglang.srt import single_batch_overlap
+from sglang.srt.configs.model_config import (
+    get_nsa_index_head_dim,
+    get_nsa_index_n_heads,
+    get_nsa_index_topk,
+    is_deepseek_nsa,
+)
+from sglang.srt.debug_utils.dumper import dumper
 from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_pp_group,
@@ -43,6 +51,11 @@ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.amx_utils import PackWeightMethod
+from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
+    NPUFusedMLAPreprocess,
+    is_mla_preprocess_enabled,
+)
+from sglang.srt.layers.attention.nsa.nsa_indexer import Indexer
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
     LayerScatterModes,
@@ -65,10 +78,11 @@ from sglang.srt.layers.moe import (
     get_deepep_mode,
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
+    should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
@@ -96,6 +110,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.single_batch_overlap import SboFlags
 from sglang.srt.two_batch_overlap import (
     MaybeTboDeepEPDispatcher,
     model_forward_maybe_tbo,
@@ -112,6 +127,7 @@ from sglang.srt.utils import (
     is_cpu,
     is_cuda,
     is_flashinfer_available,
+    is_gfx95_supported,
     is_hip,
     is_non_idle_and_non_empty,
     is_npu,
@@ -129,11 +145,28 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _device_sm = get_device_sm()
+_is_gfx95_supported = is_gfx95_supported()
+
+_use_aiter_gfx95 = _use_aiter and _is_gfx95_supported
+
+if _use_aiter_gfx95:
+    from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights
+    from sglang.srt.layers.quantization.rocm_mxfp4_utils import (
+        batched_gemm_afp4wfp4_pre_quant,
+        fused_flatten_mxfp4_quant,
+        fused_rms_mxfp4_quant,
+    )
+    from sglang.srt.layers.rocm_linear_utils import (
+        aiter_dsv3_router_gemm,
+        fused_qk_rope_cat,
+        get_dsv3_gemm_output_zero_allocator_size,
+    )

 if _is_cuda:
     from sgl_kernel import (
         awq_dequantize,
         bmm_fp8,
+        concat_mla_k,
         dsv3_fused_a_gemm,
         dsv3_router_gemm,
         merge_state_v2,
@@ -141,16 +174,18 @@ if _is_cuda:
 elif _is_cpu and _is_cpu_amx_available:
     pass
 elif _is_hip:
+    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
+        decode_attention_fwd_grouped_rope,
+    )
     from sglang.srt.layers.quantization.awq_triton import (
         awq_dequantize_triton as awq_dequantize,
     )
+elif _is_npu:
+    import custom_ops
+    import sgl_kernel_npu
+    import torch_npu
 else:
-    from vllm._custom_ops import awq_dequantize
-
-    if _is_hip:
-        from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
-            decode_attention_fwd_grouped_rope,
-        )
+    pass

 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
@@ -158,6 +193,21 @@ _is_sm100_supported = is_cuda() and is_sm100_supported()

 logger = logging.getLogger(__name__)

+FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [
+    "fa3",
+    "nsa",
+    "flashinfer",
+    "cutlass_mla",
+    "trtllm_mla",
+    "ascend",
+]
+
+
+def add_forward_absorb_core_attention_backend(backend_name):
+    if backend_name not in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
+        FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.append(backend_name)
+        logger.info(f"Added {backend_name} to FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.")
+

 class AttnForwardMethod(IntEnum):
     # Use multi-head attention
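FORWARD_ABSORB_CORE_ATTENTION_BACKENDS replaces the hard-coded backend chain that forward_absorb_core used to check (see the -1295,27 hunk further down). A minimal sketch of how an out-of-tree backend could opt in; the backend name "my_mla_backend" is hypothetical and not part of this diff:

# Hypothetical registration of a plugin backend for the absorbed-MLA core path.
from sglang.srt.models.deepseek_v2 import add_forward_absorb_core_attention_backend

add_forward_absorb_core_attention_backend("my_mla_backend")
# forward_absorb_core() will now send "my_mla_backend" through the same
# attn_mqa(q_nope_out, k_nope, ..., q_rope=q_pe, k_rope=k_pe) call as fa3,
# flashinfer, cutlass_mla, trtllm_mla, nsa, and ascend.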
@@ -166,6 +216,9 @@ class AttnForwardMethod(IntEnum):
     # Use absorbed multi-latent attention
     MLA = auto()

+    # Use Deepseek V3.2 sparse multi-latent attention
+    NPU_MLA_SPARSE = auto()
+
     # Use multi-head attention, but with KV cache chunked.
     # This method can avoid OOM when prefix lengths are long.
     MHA_CHUNKED_KV = auto()
@@ -177,6 +230,146 @@ class AttnForwardMethod(IntEnum):
     MLA_FUSED_ROPE_CPU = auto()


+def _dispatch_mla_subtype(attn, forward_batch):
+    if _is_hip:
+        if attn.rocm_fused_decode_mla and forward_batch.forward_mode.is_decode():
+            return AttnForwardMethod.MLA_FUSED_ROPE
+        else:
+            return AttnForwardMethod.MLA
+    else:
+        if hasattr(attn, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(attn):
+            return AttnForwardMethod.MLA_FUSED_ROPE_CPU
+        else:
+            return AttnForwardMethod.MLA
+
+
+class AttentionBackendRegistry:
+    _handlers = {}
+
+    @classmethod
+    def register(cls, backend_name, handler_func):
+        cls._handlers[backend_name] = handler_func
+
+    @classmethod
+    def get_handler(cls, backend_name):
+        return cls._handlers.get(backend_name, cls._handlers.get("triton"))
+
+
+def handle_attention_ascend(attn, forward_batch):
+    if (
+        forward_batch.forward_mode.is_extend()
+        and not forward_batch.forward_mode.is_target_verify()
+        and not forward_batch.forward_mode.is_draft_extend()
+    ):
+        if hasattr(attn, "indexer"):
+            return AttnForwardMethod.NPU_MLA_SPARSE
+        else:
+            return AttnForwardMethod.MHA
+    else:
+        if hasattr(attn, "indexer"):
+            return AttnForwardMethod.NPU_MLA_SPARSE
+        else:
+            return AttnForwardMethod.MLA
+
+
+def _get_sum_extend_prefix_lens(forward_batch):
+    return (
+        sum(forward_batch.extend_prefix_lens_cpu)
+        if forward_batch.extend_prefix_lens_cpu is not None
+        else 0
+    )
+
+
+def _is_extend_without_speculative(forward_batch):
+    return (
+        forward_batch.forward_mode.is_extend()
+        and not forward_batch.forward_mode.is_target_verify()
+        and not forward_batch.forward_mode.is_draft_extend()
+    )
+
+
+def _handle_attention_backend(
+    attn: DeepseekV2AttentionMLA, forward_batch, backend_name
+):
+    sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
+    disable_ragged = (
+        backend_name in ["flashinfer", "flashmla"]
+    ) and attn.flashinfer_mla_disable_ragged
+
+    if (
+        not disable_ragged
+        and _is_extend_without_speculative(forward_batch)
+        and (
+            (
+                sum_extend_prefix_lens >= attn.chunked_prefix_cache_threshold
+                and not attn.disable_chunked_prefix_cache
+            )
+            or sum_extend_prefix_lens == 0
+        )
+    ):
+        return AttnForwardMethod.MHA_CHUNKED_KV
+    else:
+        return _dispatch_mla_subtype(attn, forward_batch)
+
+
+def handle_attention_flashinfer(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "flashinfer")
+
+
+def handle_attention_fa3(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "fa3")
+
+
+def handle_attention_flashmla(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "flashmla")
+
+
+def handle_attention_cutlass_mla(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "cutlass_mla")
+
+
+def handle_attention_fa4(attn, forward_batch):
+    # TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
+    return AttnForwardMethod.MHA_CHUNKED_KV
+
+
+def handle_attention_trtllm_mla(attn, forward_batch):
+    sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
+    if _is_extend_without_speculative(forward_batch) and (
+        not attn.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
+    ):
+        return AttnForwardMethod.MHA_CHUNKED_KV
+    else:
+        return _dispatch_mla_subtype(attn, forward_batch)
+
+
+def handle_attention_aiter(attn, forward_batch):
+    if _is_extend_without_speculative(forward_batch):
+        if is_dp_attention_enabled():
+            if sum(forward_batch.extend_prefix_lens_cpu) == 0:
+                return AttnForwardMethod.MHA
+            else:
+                return AttnForwardMethod.MLA
+        else:
+            return AttnForwardMethod.MHA
+    else:
+        return AttnForwardMethod.MLA
+
+
+def handle_attention_nsa(attn, forward_batch):
+    return AttnForwardMethod.MLA
+
+
+def handle_attention_triton(attn, forward_batch):
+    if (
+        _is_extend_without_speculative(forward_batch)
+        and sum(forward_batch.extend_prefix_lens_cpu) == 0
+    ):
+        return AttnForwardMethod.MHA
+    else:
+        return _dispatch_mla_subtype(attn, forward_batch)
+
+
 class DeepseekV2MLP(nn.Module):
     def __init__(
         self,
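dispatch_attn_forward_method (rewritten in the -968,96 hunk below) now resolves the backend name through AttentionBackendRegistry.get_handler and calls the returned handler. The register() calls are not visible in this excerpt, so the wiring sketched here is an assumption based on the handler functions defined above:

# Presumed wiring; get_handler() falls back to the "triton" handler for any
# backend name that was never registered.
AttentionBackendRegistry.register("triton", handle_attention_triton)
AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
AttentionBackendRegistry.register("fa3", handle_attention_fa3)
AttentionBackendRegistry.register("aiter", handle_attention_aiter)
AttentionBackendRegistry.register("ascend", handle_attention_ascend)

def dispatch(attn, forward_batch, backend_name):
    # Mirrors the new dispatch path: a dictionary lookup plus a default.
    handler = AttentionBackendRegistry.get_handler(backend_name)
    return handler(attn, forward_batch)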
@@ -224,10 +417,21 @@ class DeepseekV2MLP(nn.Module):
         forward_batch=None,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x

+        if (
+            gemm_output_zero_allocator is not None
+            and x.shape[0] <= 256
+            and self.gate_up_proj.weight.dtype == torch.uint8
+        ):
+            y = gemm_output_zero_allocator.allocate(
+                x.shape[0] * self.gate_up_proj.output_size_per_partition
+            ).view(x.shape[0], self.gate_up_proj.output_size_per_partition)
+            x = (x, None, y)
+
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(
@@ -240,6 +444,7 @@ class MoEGate(nn.Module):
     def __init__(
         self,
         config,
+        quant_config,
         prefix: str = "",
         is_nextn: bool = False,
     ):
@@ -249,15 +454,22 @@ class MoEGate(nn.Module):
             torch.empty((config.n_routed_experts, config.hidden_size))
         )
         if config.topk_method == "noaux_tc":
+            correction_bias_dtype = (
+                torch.bfloat16
+                if quant_config is not None
+                and quant_config.get_name() == "modelopt_fp4"
+                and should_use_flashinfer_trtllm_moe()
+                else torch.float32
+            )
             self.e_score_correction_bias = nn.Parameter(
-                torch.empty((config.n_routed_experts), dtype=torch.float32)
+                torch.empty((config.n_routed_experts), dtype=correction_bias_dtype)
             )
         else:
             self.e_score_correction_bias = None
         if _is_cpu and _is_cpu_amx_available:
             self.quant_method = PackWeightMethod(weight_names=["weight"])

-    def forward(self, hidden_states):
+    def forward(self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None):
         if use_intel_amx_backend(self):
             return torch.ops.sgl_kernel.weight_packed_linear(
                 hidden_states,
@@ -271,11 +483,17 @@ class MoEGate(nn.Module):
             _is_cuda
             and hidden_states.shape[0] <= 16
             and hidden_states.shape[1] == 7168
-            and self.weight.shape[0] == 256
+            and (self.weight.shape[0] == 256 or self.weight.shape[0] == 384)
             and _device_sm >= 90
         ):
             # router gemm output float32
-            logits = dsv3_router_gemm(hidden_states, self.weight)
+            logits = dsv3_router_gemm(
+                hidden_states, self.weight, out_dtype=torch.float32
+            )
+        elif _use_aiter_gfx95 and hidden_states.shape[0] <= 256:
+            logits = aiter_dsv3_router_gemm(
+                hidden_states, self.weight, gemm_output_zero_allocator
+            )
         else:
             logits = F.linear(hidden_states, self.weight, None)

@@ -319,7 +537,10 @@ class DeepseekV2MoE(nn.Module):
         )

         self.gate = MoEGate(
-            config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("gate", prefix),
+            is_nextn=is_nextn,
         )

         self.experts = get_moe_impl_class(quant_config)(
@@ -344,9 +565,12 @@ class DeepseekV2MoE(nn.Module):
             num_fused_shared_experts=self.num_fused_shared_experts,
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
+            quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
-            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
-            force_topk=quant_config is None,
+            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk,
+            # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
+            # and requires the output format to be standard. We use quant_config to determine the output format.
+            output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
         )

         self.shared_experts_is_int8 = False
@@ -439,6 +663,7 @@ class DeepseekV2MoE(nn.Module):
         forward_batch: Optional[ForwardBatch] = None,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024
@@ -452,12 +677,14 @@ class DeepseekV2MoE(nn.Module):
                     hidden_states,
                     should_allreduce_fusion,
                     use_reduce_scatter,
+                    gemm_output_zero_allocator,
                 )
             else:
                 return self.forward_normal(
                     hidden_states,
                     should_allreduce_fusion,
                     use_reduce_scatter,
+                    gemm_output_zero_allocator,
                 )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
@@ -467,15 +694,18 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:

         current_stream = torch.cuda.current_stream()
         self.alt_stream.wait_stream(current_stream)
-        shared_output = self._forward_shared_experts(hidden_states)
+        shared_output = self._forward_shared_experts(
+            hidden_states, gemm_output_zero_allocator
+        )

         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(hidden_states)
+            router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
             final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:
@@ -502,6 +732,7 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
@@ -509,9 +740,11 @@ class DeepseekV2MoE(nn.Module):
             return self.forward_cpu(hidden_states, should_allreduce_fusion)

         if hidden_states.shape[0] > 0:
-            shared_output = self._forward_shared_experts(hidden_states)
+            shared_output = self._forward_shared_experts(
+                hidden_states, gemm_output_zero_allocator
+            )
             # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(hidden_states)
+            router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
         else:
             shared_output = None
@@ -601,7 +834,8 @@ class DeepseekV2MoE(nn.Module):
         if hidden_states.shape[0] > 0:
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            shared_output = self._forward_shared_experts(hidden_states)
+            if not SboFlags.fuse_shared_experts_inside_sbo():
+                shared_output = self._forward_shared_experts(hidden_states)
             topk_weights, topk_idx, _ = self.topk(
                 hidden_states,
                 router_logits,
@@ -615,25 +849,39 @@ class DeepseekV2MoE(nn.Module):
                 hidden_states.device
             )

-        final_hidden_states = self.experts(
+        final_hidden_states, sbo_shared_output = single_batch_overlap.execute_sbo(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
             forward_batch=forward_batch,
+            # SBO args
+            forward_shared_experts=lambda: self._forward_shared_experts(hidden_states),
+            experts=self.experts,
+            alt_stream=self.alt_stream,
         )
+        if sbo_shared_output is not None:
+            shared_output = sbo_shared_output

         if shared_output is not None:
             x = shared_output
-            x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
+            if self.experts.should_fuse_routed_scaling_factor_in_topk:
+                x.add_(final_hidden_states)
+            else:
+                x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
             final_hidden_states = x
         else:
-            final_hidden_states *= self.routed_scaling_factor
+            if not self.experts.should_fuse_routed_scaling_factor_in_topk:
+                final_hidden_states *= self.routed_scaling_factor

         return final_hidden_states

-    def _forward_shared_experts(self, hidden_states):
-        if self.num_fused_shared_experts == 0:
-            return self.shared_experts(hidden_states)
+    def _forward_shared_experts(
+        self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None
+    ):
+        if (hidden_states.shape[0] > 0) and (self.num_fused_shared_experts == 0):
+            return self.shared_experts(
+                hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
+            )
         else:
             return None

@@ -683,6 +931,7 @@ class DeepseekV2MoE(nn.Module):
         if self.ep_size > 1:
             self.experts.deepep_dispatcher.dispatch_a(
                 hidden_states=state.hidden_states_mlp_input,
+                input_global_scale=None,
                 topk_idx=state.pop("topk_idx_local"),
                 topk_weights=state.pop("topk_weights_local"),
                 forward_batch=state.forward_batch,
@@ -783,6 +1032,10 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings

+        # NOTE modification to rope_scaling must be done early enough, b/c e.g. Indexer needs it
+        if rope_scaling:
+            rope_scaling["rope_type"] = "deepseek_yarn"
+
         # For tensor parallel attention
         if self.q_lora_rank is not None:
             self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
@@ -820,6 +1073,26 @@ class DeepseekV2AttentionMLA(nn.Module):
                 prefix=add_prefix("kv_a_proj_with_mqa", prefix),
             )

+        self.use_nsa = is_deepseek_nsa(config)
+        if self.use_nsa:
+            self.indexer = Indexer(
+                hidden_size=hidden_size,
+                index_n_heads=get_nsa_index_n_heads(config),
+                index_head_dim=get_nsa_index_head_dim(config),
+                rope_head_dim=qk_rope_head_dim,
+                index_topk=get_nsa_index_topk(config),
+                q_lora_rank=q_lora_rank,
+                max_position_embeddings=max_position_embeddings,
+                rope_theta=rope_theta,
+                scale_fmt="ue8m0",
+                block_size=128,
+                rope_scaling=rope_scaling,
+                prefix=add_prefix("indexer", prefix),
+                quant_config=quant_config,
+                layer_id=layer_id,
+                alt_stream=alt_stream,
+            )
+
         self.kv_b_proj = ColumnParallelLinear(
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
@@ -842,9 +1115,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)

-        if rope_scaling:
-            rope_scaling["rope_type"] = "deepseek_yarn"
-
         self.rotary_emb = get_rope_wrapper(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
@@ -968,96 +1238,34 @@ class DeepseekV2AttentionMLA(nn.Module):
             self.weight_block_size = (
                 self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
             )
+        self.is_mla_preprocess_enabled = is_mla_preprocess_enabled()
+        if self.is_mla_preprocess_enabled:
+            assert (
+                quant_config is None or quant_config.get_name() == "w8a8_int8"
+            ), "MLA Preprocess only works with Unquant or W8A8Int8"
+            self.mla_preprocess = None

     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
     ) -> AttnForwardMethod:
-        def _dispatch_mla_subtype():
-            if _is_hip:
-                if (
-                    self.rocm_fused_decode_mla
-                    and forward_batch.forward_mode.is_decode()
-                ):
-                    return AttnForwardMethod.MLA_FUSED_ROPE
-                else:
-                    return AttnForwardMethod.MLA
-            else:
-                if hasattr(self, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(
-                    self
-                ):
-                    return AttnForwardMethod.MLA_FUSED_ROPE_CPU
-                else:
-                    return AttnForwardMethod.MLA
-
         # Determine attention backend used by current forward batch
         if forward_batch.forward_mode.is_decode_or_idle():
             attention_backend = global_server_args_dict["decode_attention_backend"]
+        elif (
+            forward_batch.forward_mode.is_target_verify()
+            or forward_batch.forward_mode.is_draft_extend()
+        ):
+            # Use the specified backend for speculative operations (both verify and draft extend)
+            if global_server_args_dict["speculative_attention_mode"] == "decode":
+                attention_backend = global_server_args_dict["decode_attention_backend"]
+            else:  # default to prefill
+                attention_backend = global_server_args_dict["prefill_attention_backend"]
         else:
             attention_backend = global_server_args_dict["prefill_attention_backend"]
         self.current_attention_backend = attention_backend

-        if attention_backend == "ascend":
-            if (
-                forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-            ):
-                return AttnForwardMethod.MHA
-            else:
-                return AttnForwardMethod.MLA
-        elif (
-            attention_backend == "flashinfer"
-            or attention_backend == "fa3"
-            or attention_backend == "flashmla"
-            or attention_backend == "trtllm_mla"
-            or attention_backend == "cutlass_mla"
-        ):
-            # Use MHA with chunked KV cache when prefilling on long sequences.
-            sum_extend_prefix_lens = (
-                sum(forward_batch.extend_prefix_lens_cpu)
-                if forward_batch.extend_prefix_lens_cpu is not None
-                else 0
-            )
-            # Flashinfer MLA: Do not absorb when enabling ragged prefill
-            disable_ragged = (
-                attention_backend == "flashinfer" or attention_backend == "flashmla"
-            ) and self.flashinfer_mla_disable_ragged
-            if (
-                not disable_ragged
-                and forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-                and (
-                    (
-                        sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
-                        and not self.disable_chunked_prefix_cache
-                    )
-                    or sum_extend_prefix_lens == 0
-                )
-            ):
-                return AttnForwardMethod.MHA_CHUNKED_KV
-            else:
-                return _dispatch_mla_subtype()
-        elif attention_backend == "aiter":
-            if (
-                forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-            ):
-                return AttnForwardMethod.MHA
-            else:
-                return AttnForwardMethod.MLA
-        else:
-            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
-            if (
-                forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-                and sum(forward_batch.extend_prefix_lens_cpu) == 0
-            ):
-                return AttnForwardMethod.MHA
-            else:
-                return _dispatch_mla_subtype()
+        handler = AttentionBackendRegistry.get_handler(attention_backend)
+        return handler(self, forward_batch)

     def op_prepare(self, state):
         state.attn_intermediate_state = self.forward_prepare(
@@ -1097,14 +1305,21 @@ class DeepseekV2AttentionMLA(nn.Module):
         if self.attn_mha.kv_b_proj is None:
             self.attn_mha.kv_b_proj = self.kv_b_proj

-        if hidden_states.shape[0] == 0:
-            assert (
-                not self.o_proj.reduce_results
-            ), "short-circuiting allreduce will lead to hangs"
-            return hidden_states, None, forward_batch, None
+        # when hidden_states is a tuple of tensors, the tuple will include quantized weight and scale tensor
+        if isinstance(hidden_states, tuple):
+            if hidden_states[0].shape[0] == 0:
+                assert (
+                    not self.o_proj.reduce_results
+                ), "short-circuiting allreduce will lead to hangs"
+                return hidden_states[0]
+        else:
+            if hidden_states.shape[0] == 0:
+                assert (
+                    not self.o_proj.reduce_results
+                ), "short-circuiting allreduce will lead to hangs"
+                return hidden_states, None, forward_batch, None

         attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
-
         if attn_forward_method == AttnForwardMethod.MHA:
             inner_state = self.forward_normal_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
@@ -1114,7 +1329,30 @@ class DeepseekV2AttentionMLA(nn.Module):
                 positions, hidden_states, forward_batch, zero_allocator
             )
         elif attn_forward_method == AttnForwardMethod.MLA:
-            inner_state = self.forward_absorb_prepare(
+            if not self.is_mla_preprocess_enabled:
+                inner_state = self.forward_absorb_prepare(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
+            else:
+                # TODO(iforgetmyname): to be separated as a standalone func
+                if self.mla_preprocess is None:
+                    self.mla_preprocess = NPUFusedMLAPreprocess(
+                        self.fused_qkv_a_proj_with_mqa,
+                        self.q_a_layernorm,
+                        self.kv_a_layernorm,
+                        self.q_b_proj,
+                        self.w_kc,
+                        self.rotary_emb,
+                        self.layer_id,
+                        self.num_local_heads,
+                        self.qk_nope_head_dim,
+                        self.qk_rope_head_dim,
+                    )
+                inner_state = self.mla_preprocess.forward(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
+        elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
+            inner_state = self.forward_npu_sparse_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
             )
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
@@ -1142,6 +1380,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             return self.forward_normal_chunked_kv_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MLA:
             return self.forward_absorb_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
+            return self.forward_npu_sparse_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
             return self.forward_absorb_fused_mla_rope_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
@@ -1180,8 +1420,18 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
         q[..., self.qk_nope_head_dim :] = q_pe
         k = torch.empty_like(q)
-        k[..., : self.qk_nope_head_dim] = k_nope
-        k[..., self.qk_nope_head_dim :] = k_pe
+
+        # Temporary for DeepSeek V3/R1 only, but can generalize if needed
+        if (
+            _is_cuda
+            and (self.num_local_heads == 128)
+            and (self.qk_nope_head_dim == 128)
+            and (self.qk_rope_head_dim == 64)
+        ):
+            concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
+        else:
+            k[..., : self.qk_nope_head_dim] = k_nope
+            k[..., self.qk_nope_head_dim :] = k_pe

         if not _is_npu:
             latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
@@ -1211,7 +1461,10 @@ class DeepseekV2AttentionMLA(nn.Module):
         """
         return (
             self.current_attention_backend == "trtllm_mla"
-            and forward_batch.forward_mode.is_decode_or_idle()
+            and (
+                forward_batch.forward_mode.is_decode_or_idle()
+                or forward_batch.forward_mode.is_target_verify()
+            )
             and forward_batch.attn_backend.data_type == torch.float8_e4m3fn
         )

@@ -1224,8 +1477,13 @@ class DeepseekV2AttentionMLA(nn.Module):
     ):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

+        q_lora = None
         if self.q_lora_rank is not None:
-            if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
+            if (
+                (not isinstance(hidden_states, tuple))
+                and hidden_states.shape[0] <= 16
+                and self.use_min_latency_fused_a_gemm
+            ):
                 fused_qkv_a_proj_out = dsv3_fused_a_gemm(
                     hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
                 )
@@ -1245,8 +1503,22 @@ class DeepseekV2AttentionMLA(nn.Module):
                     k_nope = self.kv_a_layernorm(k_nope)
                 current_stream.wait_stream(self.alt_stream)
             else:
-                q = self.q_a_layernorm(q)
-                k_nope = self.kv_a_layernorm(k_nope)
+                if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
+                    q, k_nope = fused_rms_mxfp4_quant(
+                        q,
+                        self.q_a_layernorm.weight,
+                        self.q_a_layernorm.variance_epsilon,
+                        k_nope,
+                        self.kv_a_layernorm.weight,
+                        self.kv_a_layernorm.variance_epsilon,
+                    )
+                else:
+                    q = self.q_a_layernorm(q)
+                    k_nope = self.kv_a_layernorm(k_nope)
+
+            # q_lora needed by indexer
+            if self.use_nsa:
+                q_lora = q

             k_nope = k_nope.unsqueeze(1)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
@@ -1278,10 +1550,27 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = q_nope_out[:, :expected_m, :]
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
-            q_nope_out = torch.bmm(
-                q_nope.to(torch.bfloat16).transpose(0, 1),
-                self.w_kc.to(torch.bfloat16) * self.w_scale,
-            )
+            if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
+                x = q_nope.transpose(0, 1)
+                q_nope_out = torch.empty(
+                    x.shape[0],
+                    x.shape[1],
+                    self.w_kc.shape[2],
+                    device=x.device,
+                    dtype=torch.bfloat16,
+                )
+                batched_gemm_afp4wfp4_pre_quant(
+                    x,
+                    self.w_kc.transpose(-2, -1),
+                    self.w_scale_k.transpose(-2, -1),
+                    torch.bfloat16,
+                    q_nope_out,
+                )
+            else:
+                q_nope_out = torch.bmm(
+                    q_nope.to(torch.bfloat16).transpose(0, 1),
+                    self.w_kc.to(torch.bfloat16) * self.w_scale,
+                )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                 q_nope.transpose(0, 1),
@@ -1295,27 +1584,51 @@ class DeepseekV2AttentionMLA(nn.Module):

  q_nope_out = q_nope_out.transpose(0, 1)

- if not self._fuse_rope_for_trtllm_mla(forward_batch):
+ if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
+ not _use_aiter or not _is_gfx95_supported or self.use_nsa
+ ):
  q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

- return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+ topk_indices = None
+ if q_lora is not None:
+ topk_indices = self.indexer(
+ x=hidden_states,
+ q_lora=q_lora,
+ positions=positions,
+ forward_batch=forward_batch,
+ layer_id=self.layer_id,
+ )
+
+ return (
+ q_pe,
+ k_pe,
+ q_nope_out,
+ k_nope,
+ forward_batch,
+ zero_allocator,
+ positions,
+ topk_indices,
+ )

  def forward_absorb_core(
- self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+ self,
+ q_pe,
+ k_pe,
+ q_nope_out,
+ k_nope,
+ forward_batch,
+ zero_allocator,
+ positions,
+ topk_indices,
  ):
- if (
- self.current_attention_backend == "fa3"
- or self.current_attention_backend == "flashinfer"
- or self.current_attention_backend == "cutlass_mla"
- or self.current_attention_backend == "trtllm_mla"
- or self.current_attention_backend == "ascend"
- ):
+ if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
  extra_args = {}
  if self._fuse_rope_for_trtllm_mla(forward_batch):
  extra_args = {
  "cos_sin_cache": self.rotary_emb.cos_sin_cache,
  "is_neox": self.rotary_emb.is_neox_style,
  }
+
  attn_output = self.attn_mqa(
  q_nope_out,
  k_nope,
@@ -1324,11 +1637,33 @@ class DeepseekV2AttentionMLA(nn.Module):
  q_rope=q_pe,
  k_rope=k_pe,
  **extra_args,
+ **(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
  )
  else:
- q = torch.cat([q_nope_out, q_pe], dim=-1)
- k = torch.cat([k_nope, k_pe], dim=-1)
- attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
+ if _use_aiter_gfx95:
+ cos = self.rotary_emb.cos_cache
+ sin = self.rotary_emb.sin_cache
+ q, k = fused_qk_rope_cat(
+ q_nope_out,
+ q_pe,
+ k_nope,
+ k_pe,
+ positions,
+ cos,
+ sin,
+ self.rotary_emb.is_neox_style,
+ )
+ else:
+ q = torch.cat([q_nope_out, q_pe], dim=-1)
+ k = torch.cat([k_nope, k_pe], dim=-1)
+
+ attn_output = self.attn_mqa(
+ q,
+ k,
+ k_nope,
+ forward_batch,
+ **(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
+ )
  attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

  if self.use_deep_gemm_bmm:
@@ -1352,11 +1687,34 @@ class DeepseekV2AttentionMLA(nn.Module):
  )
  elif _is_hip:
  # TODO(haishaw): add bmm_fp8 to ROCm
- attn_bmm_output = torch.bmm(
- attn_output.to(torch.bfloat16).transpose(0, 1),
- self.w_vc.to(torch.bfloat16) * self.w_scale,
- )
- attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
+ if _use_aiter_gfx95 and self.w_vc.dtype == torch.uint8:
+ x = attn_output.transpose(0, 1)
+ attn_bmm_output = torch.empty(
+ x.shape[0],
+ x.shape[1],
+ self.w_vc.shape[2],
+ device=x.device,
+ dtype=torch.bfloat16,
+ )
+ batched_gemm_afp4wfp4_pre_quant(
+ x,
+ self.w_vc.transpose(-2, -1),
+ self.w_scale_v.transpose(-2, -1),
+ torch.bfloat16,
+ attn_bmm_output,
+ )
+ else:
+ attn_bmm_output = torch.bmm(
+ attn_output.to(torch.bfloat16).transpose(0, 1),
+ self.w_vc.to(torch.bfloat16) * self.w_scale,
+ )
+
+ if self.o_proj.weight.dtype == torch.uint8:
+ attn_bmm_output = attn_bmm_output.transpose(0, 1)
+ attn_bmm_output = fused_flatten_mxfp4_quant(attn_bmm_output)
+ else:
+ attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
+
  elif self.w_vc.dtype == torch.float8_e4m3fn:
  attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
  attn_output.transpose(0, 1),
@@ -1387,6 +1745,221 @@ class DeepseekV2AttentionMLA(nn.Module):

  return output

+ def forward_npu_sparse_prepare(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ forward_batch: ForwardBatch,
+ zero_allocator: BumpAllocator,
+ ):
+ """
+ Reuse `self.q_lora_rank is not None` branch from forward_absorb_prepare
+ """
+ if self.is_mla_preprocess_enabled and forward_batch.forward_mode.is_decode():
+ if self.mla_preprocess is None:
+ self.mla_preprocess = NPUFusedMLAPreprocess(
+ self.fused_qkv_a_proj_with_mqa,
+ self.q_a_layernorm,
+ self.kv_a_layernorm,
+ self.q_b_proj,
+ self.w_kc,
+ self.rotary_emb,
+ self.layer_id,
+ self.num_local_heads,
+ self.qk_nope_head_dim,
+ self.qk_rope_head_dim,
+ )
+ (
+ q_pe,
+ k_pe,
+ q_nope_out,
+ k_nope,
+ forward_batch,
+ zero_allocator,
+ positions,
+ ) = self.mla_preprocess.forward(
+ positions, hidden_states, forward_batch, zero_allocator
+ )
+
+ fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
+ q, _ = fused_qkv_a_proj_out.split(
+ [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+ )
+ q_lora = self.q_a_layernorm(q)
+ else:
+ from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
+ if (
+ (not isinstance(hidden_states, tuple))
+ and hidden_states.shape[0] <= 16
+ and self.use_min_latency_fused_a_gemm
+ ):
+ fused_qkv_a_proj_out = dsv3_fused_a_gemm(
+ hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
+ )
+ else:
+ fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
+ q, latent_cache = fused_qkv_a_proj_out.split(
+ [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+ )
+ k_nope = latent_cache[..., : self.kv_lora_rank]
+
+ # overlap qk norm
+ if self.alt_stream is not None and get_is_capture_mode():
+ current_stream = torch.cuda.current_stream()
+ self.alt_stream.wait_stream(current_stream)
+ q = self.q_a_layernorm(q)
+ with torch.cuda.stream(self.alt_stream):
+ k_nope = self.kv_a_layernorm(k_nope)
+ current_stream.wait_stream(self.alt_stream)
+ else:
+ if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
+ q, k_nope = fused_rms_mxfp4_quant(
+ q,
+ self.q_a_layernorm.weight,
+ self.q_a_layernorm.variance_epsilon,
+ k_nope,
+ self.kv_a_layernorm.weight,
+ self.kv_a_layernorm.variance_epsilon,
+ )
+ else:
+ q = self.q_a_layernorm(q)
+ k_nope = self.kv_a_layernorm(k_nope)
+
+ q_lora = q.clone() # required for topk_indices
+ k_nope = k_nope.unsqueeze(1)
+ q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+
+ q_nope, q_pe = q.split(
+ [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+ )
+ k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
+
+ if self.use_deep_gemm_bmm:
+ q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
+ per_token_group_quant_mla_deep_gemm_masked_fp8(
+ q_nope.transpose(0, 1)
+ )
+ )
+ q_nope_out = q_nope.new_empty(
+ (self.num_local_heads, aligned_m, self.kv_lora_rank)
+ )
+ deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+ (q_nope_val, q_nope_scale),
+ (self.w_kc, self.w_scale_k),
+ q_nope_out,
+ masked_m,
+ expected_m,
+ )
+ q_nope_out = q_nope_out[:, :expected_m, :]
+ elif _is_hip:
+ # TODO(haishaw): add bmm_fp8 to ROCm
+ if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
+ x = q_nope.transpose(0, 1)
+ q_nope_out = torch.empty(
+ x.shape[0],
+ x.shape[1],
+ self.w_kc.shape[2],
+ device=x.device,
+ dtype=torch.bfloat16,
+ )
+ batched_gemm_afp4wfp4_pre_quant(
+ x,
+ self.w_kc.transpose(-2, -1),
+ self.w_scale_k.transpose(-2, -1),
+ torch.bfloat16,
+ q_nope_out,
+ )
+ else:
+ q_nope_out = torch.bmm(
+ q_nope.to(torch.bfloat16).transpose(0, 1),
+ self.w_kc.to(torch.bfloat16) * self.w_scale,
+ )
+ elif self.w_kc.dtype == torch.float8_e4m3fn:
+ q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+ q_nope.transpose(0, 1),
+ zero_allocator.allocate(1),
+ )
+ q_nope_out = bmm_fp8(
+ q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
+ )
+ else:
+ q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
+
+ q_nope_out = q_nope_out.transpose(0, 1)
+
+ if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
+ not _use_aiter or not _is_gfx95_supported
+ ):
+ q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+
+ # TODO: multi-stream indexer
+ topk_indices = self.indexer(
+ hidden_states, q_lora, positions, forward_batch, self.layer_id
+ )
+
+ return (
+ q_pe,
+ k_pe,
+ q_nope_out,
+ k_nope,
+ topk_indices,
+ forward_batch,
+ zero_allocator,
+ positions,
+ )
+
+ def forward_npu_sparse_core(
+ self,
+ q_pe,
+ k_pe,
+ q_nope_out,
+ k_nope,
+ topk_indices,
+ forward_batch,
+ zero_allocator,
+ positions,
+ ):
+ attn_output = self.attn_mqa(
+ q_nope_out.contiguous(),
+ k_nope.contiguous(),
+ k_nope.contiguous(),
+ forward_batch,
+ save_kv_cache=True, # False if forward_batch.forward_mode.is_extend() else True,
+ q_rope=q_pe.contiguous(),
+ k_rope=k_pe.contiguous(),
+ topk_indices=topk_indices,
+ )
+ attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+
+ attn_bmm_output = torch.empty(
+ (attn_output.shape[0], self.num_local_heads, self.v_head_dim),
+ dtype=attn_output.dtype,
+ device=attn_output.device,
+ )
+
+ if not forward_batch.forward_mode.is_decode():
+ attn_output = attn_output.transpose(0, 1)
+ torch.bmm(
+ attn_output,
+ self.w_vc,
+ out=attn_bmm_output.view(
+ -1, self.num_local_heads, self.v_head_dim
+ ).transpose(0, 1),
+ )
+ else:
+ attn_output = attn_output.contiguous()
+ torch.ops.npu.batch_matmul_transpose(
+ attn_output, self.w_vc, attn_bmm_output
+ )
+
+ attn_bmm_output = attn_bmm_output.reshape(
+ -1, self.num_local_heads * self.v_head_dim
+ )
+
+ output, _ = self.o_proj(attn_bmm_output)
+ return output
+
  def forward_absorb_fused_mla_rope_prepare(
  self,
  positions: torch.Tensor,
@@ -1678,9 +2251,11 @@ class DeepseekV2AttentionMLA(nn.Module):
  latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
  self.attn_mha.layer_id
  )
- latent_cache = latent_cache_buf[
- forward_batch.prefix_chunk_kv_indices[i]
- ].contiguous()
+ latent_cache = (
+ latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
+ .contiguous()
+ .to(q.dtype)
+ )

  kv_a_normed, k_pe = latent_cache.split(
  [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
@@ -1710,6 +2285,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  tmp_lse = torch.empty_like(accum_lse)
  merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
  accum_output, accum_lse = tmp_output, tmp_lse
+ del kv, k, v, output, lse, tmp_output, tmp_lse

  return accum_output

@@ -1864,10 +2440,23 @@ class DeepseekV2DecoderLayer(nn.Module):
  forward_batch: ForwardBatch,
  residual: Optional[torch.Tensor],
  zero_allocator: BumpAllocator,
+ gemm_output_zero_allocator: BumpAllocator = None,
  ) -> torch.Tensor:
+ quant_format = (
+ "mxfp4"
+ if _is_gfx95_supported
+ and getattr(self.self_attn, "fused_qkv_a_proj_with_mqa", None) is not None
+ and getattr(self.self_attn.fused_qkv_a_proj_with_mqa, "weight", None)
+ is not None
+ and self.self_attn.fused_qkv_a_proj_with_mqa.weight.dtype == torch.uint8
+ else ""
+ )

  hidden_states, residual = self.layer_communicator.prepare_attn(
- hidden_states, residual, forward_batch
+ hidden_states,
+ residual,
+ forward_batch,
+ quant_format,
  )

  hidden_states = self.self_attn(
@@ -1891,8 +2480,16 @@ class DeepseekV2DecoderLayer(nn.Module):
  use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
  forward_batch
  )
+
+ if isinstance(self.mlp, DeepseekV2MLP):
+ gemm_output_zero_allocator = None
+
  hidden_states = self.mlp(
- hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
+ hidden_states,
+ forward_batch,
+ should_allreduce_fusion,
+ use_reduce_scatter,
+ gemm_output_zero_allocator,
  )

  if should_allreduce_fusion:
@@ -2023,8 +2620,15 @@ class DeepseekV2Model(nn.Module):
  [
  "w13_weight",
  "w2_weight",
- "w13_blockscale_swizzled",
- "w2_blockscale_swizzled",
+ # only for nvfp4
+ *(
+ [
+ "w13_blockscale_swizzled",
+ "w2_blockscale_swizzled",
+ ]
+ if hasattr(module, "w13_blockscale_swizzled")
+ else []
+ ),
  ]
  if isinstance(module, FusedMoE)
  else []
@@ -2036,6 +2640,37 @@ class DeepseekV2Model(nn.Module):
  else:
  self.norm = PPMissingLayer(return_tuple=True)

+ self.gemm_output_zero_allocator_size = 0
+ if (
+ _use_aiter_gfx95
+ and config.n_routed_experts == 256
+ and self.embed_tokens.embedding_dim == 7168
+ ):
+ num_moe_layers = sum(
+ [
+ 1
+ for i in range(len(self.layers))
+ if isinstance(self.layers[i].mlp, DeepseekV2MoE)
+ ]
+ )
+
+ allocate_size = 0
+ for i in range(len(self.layers)):
+ if isinstance(self.layers[i].mlp, DeepseekV2MoE):
+ allocate_size = self.layers[
+ i
+ ].mlp.shared_experts.gate_up_proj.output_size_per_partition
+ break
+
+ self.gemm_output_zero_allocator_size = (
+ get_dsv3_gemm_output_zero_allocator_size(
+ config.n_routed_experts,
+ num_moe_layers,
+ allocate_size,
+ self.embed_tokens.embedding_dim,
+ )
+ )
+
  def get_input_embeddings(self) -> torch.Tensor:
  return self.embed_tokens

@@ -2055,6 +2690,21 @@ class DeepseekV2Model(nn.Module):
  device=device,
  )

+ has_gemm_output_zero_allocator = hasattr(
+ self, "gemm_output_zero_allocator_size"
+ )
+
+ gemm_output_zero_allocator = (
+ BumpAllocator(
+ buffer_size=self.gemm_output_zero_allocator_size,
+ dtype=torch.float32,
+ device=device,
+ )
+ if has_gemm_output_zero_allocator
+ and self.gemm_output_zero_allocator_size > 0
+ else None
+ )
+
  if self.pp_group.is_first_rank:
  if input_embeds is None:
  hidden_states = self.embed_tokens(input_ids)
@@ -2081,7 +2731,12 @@ class DeepseekV2Model(nn.Module):
  with get_global_expert_distribution_recorder().with_current_layer(i):
  layer = self.layers[i]
  hidden_states, residual = layer(
- positions, hidden_states, forward_batch, residual, zero_allocator
+ positions,
+ hidden_states,
+ forward_batch,
+ residual,
+ zero_allocator,
+ gemm_output_zero_allocator,
  )

  if normal_end_layer != self.end_layer:
@@ -2354,6 +3009,16 @@ class DeepseekV2ForCausalLM(nn.Module):
  w_kc, w_vc = w.unflatten(
  0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
  ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+
+ if (
+ _use_aiter_gfx95
+ and self.quant_config is not None
+ and self.quant_config.get_name() == "quark"
+ ):
+ w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = (
+ quark_post_load_weights(self_attn, w, "mxfp4")
+ )
+
  if not use_deep_gemm_bmm:
  self_attn.w_kc = bind_or_assign(
  self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
@@ -2733,8 +3398,24 @@ class DeepseekV2ForCausalLM(nn.Module):
  )


+ AttentionBackendRegistry.register("ascend", handle_attention_ascend)
+ AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
+ AttentionBackendRegistry.register("fa3", handle_attention_fa3)
+ AttentionBackendRegistry.register("flashmla", handle_attention_flashmla)
+ AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla)
+ AttentionBackendRegistry.register("fa4", handle_attention_fa4)
+ AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla)
+ AttentionBackendRegistry.register("aiter", handle_attention_aiter)
+ AttentionBackendRegistry.register("nsa", handle_attention_nsa)
+ AttentionBackendRegistry.register("triton", handle_attention_triton)
+
+
  class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
  pass


- EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]
+ class DeepseekV32ForCausalLM(DeepseekV2ForCausalLM):
+ pass
+
+
+ EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM, DeepseekV32ForCausalLM]