sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -15,6 +15,7 @@
  # Adapted from:
  # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
  """Inference-only DeepseekV2 model."""
+ from __future__ import annotations

  import concurrent.futures
  import logging
@@ -25,9 +26,16 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
  import torch
  import torch.nn.functional as F
  from torch import nn
- from tqdm import tqdm
  from transformers import PretrainedConfig

+ from sglang.srt import single_batch_overlap
+ from sglang.srt.configs.model_config import (
+ get_nsa_index_head_dim,
+ get_nsa_index_n_heads,
+ get_nsa_index_topk,
+ is_deepseek_nsa,
+ )
+ from sglang.srt.debug_utils.dumper import dumper
  from sglang.srt.distributed import (
  get_moe_expert_parallel_world_size,
  get_pp_group,
@@ -43,6 +51,11 @@ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
  from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.amx_utils import PackWeightMethod
+ from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
+ NPUFusedMLAPreprocess,
+ is_mla_preprocess_enabled,
+ )
+ from sglang.srt.layers.attention.nsa.nsa_indexer import Indexer
  from sglang.srt.layers.communicator import (
  LayerCommunicator,
  LayerScatterModes,
@@ -65,10 +78,11 @@ from sglang.srt.layers.moe import (
  get_deepep_mode,
  get_moe_a2a_backend,
  should_use_flashinfer_cutlass_moe_fp4_allgather,
+ should_use_flashinfer_trtllm_moe,
  )
  from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
  from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
- from sglang.srt.layers.moe.topk import TopK
+ from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
  from sglang.srt.layers.quantization import deep_gemm_wrapper
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.quantization.fp8_kernel import (
@@ -96,6 +110,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
  from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.single_batch_overlap import SboFlags
  from sglang.srt.two_batch_overlap import (
  MaybeTboDeepEPDispatcher,
  model_forward_maybe_tbo,
@@ -151,6 +166,7 @@ if _is_cuda:
  from sgl_kernel import (
  awq_dequantize,
  bmm_fp8,
+ concat_mla_k,
  dsv3_fused_a_gemm,
  dsv3_router_gemm,
  merge_state_v2,
@@ -158,16 +174,18 @@
  elif _is_cpu and _is_cpu_amx_available:
  pass
  elif _is_hip:
+ from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
+ decode_attention_fwd_grouped_rope,
+ )
  from sglang.srt.layers.quantization.awq_triton import (
  awq_dequantize_triton as awq_dequantize,
  )
+ elif _is_npu:
+ import custom_ops
+ import sgl_kernel_npu
+ import torch_npu
  else:
- from vllm._custom_ops import awq_dequantize
-
- if _is_hip:
- from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
- decode_attention_fwd_grouped_rope,
- )
+ pass

  _is_flashinfer_available = is_flashinfer_available()
  _is_sm100_supported = is_cuda() and is_sm100_supported()
@@ -175,6 +193,21 @@ _is_sm100_supported = is_cuda() and is_sm100_supported()

  logger = logging.getLogger(__name__)

+ FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [
+ "fa3",
+ "nsa",
+ "flashinfer",
+ "cutlass_mla",
+ "trtllm_mla",
+ "ascend",
+ ]
+
+
+ def add_forward_absorb_core_attention_backend(backend_name):
+ if backend_name not in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
+ FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.append(backend_name)
+ logger.info(f"Added {backend_name} to FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.")
+

  class AttnForwardMethod(IntEnum):
  # Use multi-head attention
@@ -183,6 +216,9 @@ class AttnForwardMethod(IntEnum):
  # Use absorbed multi-latent attention
  MLA = auto()

+ # Use Deepseek V3.2 sparse multi-latent attention
+ NPU_MLA_SPARSE = auto()
+
  # Use multi-head attention, but with KV cache chunked.
  # This method can avoid OOM when prefix lengths are long.
  MHA_CHUNKED_KV = auto()
@@ -194,6 +230,146 @@
  MLA_FUSED_ROPE_CPU = auto()


+ def _dispatch_mla_subtype(attn, forward_batch):
+ if _is_hip:
+ if attn.rocm_fused_decode_mla and forward_batch.forward_mode.is_decode():
+ return AttnForwardMethod.MLA_FUSED_ROPE
+ else:
+ return AttnForwardMethod.MLA
+ else:
+ if hasattr(attn, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(attn):
+ return AttnForwardMethod.MLA_FUSED_ROPE_CPU
+ else:
+ return AttnForwardMethod.MLA
+
+
+ class AttentionBackendRegistry:
+ _handlers = {}
+
+ @classmethod
+ def register(cls, backend_name, handler_func):
+ cls._handlers[backend_name] = handler_func
+
+ @classmethod
+ def get_handler(cls, backend_name):
+ return cls._handlers.get(backend_name, cls._handlers.get("triton"))
+
+
+ def handle_attention_ascend(attn, forward_batch):
+ if (
+ forward_batch.forward_mode.is_extend()
+ and not forward_batch.forward_mode.is_target_verify()
+ and not forward_batch.forward_mode.is_draft_extend()
+ ):
+ if hasattr(attn, "indexer"):
+ return AttnForwardMethod.NPU_MLA_SPARSE
+ else:
+ return AttnForwardMethod.MHA
+ else:
+ if hasattr(attn, "indexer"):
+ return AttnForwardMethod.NPU_MLA_SPARSE
+ else:
+ return AttnForwardMethod.MLA
+
+
+ def _get_sum_extend_prefix_lens(forward_batch):
+ return (
+ sum(forward_batch.extend_prefix_lens_cpu)
+ if forward_batch.extend_prefix_lens_cpu is not None
+ else 0
+ )
+
+
+ def _is_extend_without_speculative(forward_batch):
+ return (
+ forward_batch.forward_mode.is_extend()
+ and not forward_batch.forward_mode.is_target_verify()
+ and not forward_batch.forward_mode.is_draft_extend()
+ )
+
+
+ def _handle_attention_backend(
+ attn: DeepseekV2AttentionMLA, forward_batch, backend_name
+ ):
+ sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
+ disable_ragged = (
+ backend_name in ["flashinfer", "flashmla"]
+ ) and attn.flashinfer_mla_disable_ragged
+
+ if (
+ not disable_ragged
+ and _is_extend_without_speculative(forward_batch)
+ and (
+ (
+ sum_extend_prefix_lens >= attn.chunked_prefix_cache_threshold
+ and not attn.disable_chunked_prefix_cache
+ )
+ or sum_extend_prefix_lens == 0
+ )
+ ):
+ return AttnForwardMethod.MHA_CHUNKED_KV
+ else:
+ return _dispatch_mla_subtype(attn, forward_batch)
+
+
+ def handle_attention_flashinfer(attn, forward_batch):
+ return _handle_attention_backend(attn, forward_batch, "flashinfer")
+
+
+ def handle_attention_fa3(attn, forward_batch):
+ return _handle_attention_backend(attn, forward_batch, "fa3")
+
+
+ def handle_attention_flashmla(attn, forward_batch):
+ return _handle_attention_backend(attn, forward_batch, "flashmla")
+
+
+ def handle_attention_cutlass_mla(attn, forward_batch):
+ return _handle_attention_backend(attn, forward_batch, "cutlass_mla")
+
+
+ def handle_attention_fa4(attn, forward_batch):
+ # TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
+ return AttnForwardMethod.MHA_CHUNKED_KV
+
+
+ def handle_attention_trtllm_mla(attn, forward_batch):
+ sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
+ if _is_extend_without_speculative(forward_batch) and (
+ not attn.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
+ ):
+ return AttnForwardMethod.MHA_CHUNKED_KV
+ else:
+ return _dispatch_mla_subtype(attn, forward_batch)
+
+
+ def handle_attention_aiter(attn, forward_batch):
+ if _is_extend_without_speculative(forward_batch):
+ if is_dp_attention_enabled():
+ if sum(forward_batch.extend_prefix_lens_cpu) == 0:
+ return AttnForwardMethod.MHA
+ else:
+ return AttnForwardMethod.MLA
+ else:
+ return AttnForwardMethod.MHA
+ else:
+ return AttnForwardMethod.MLA
+
+
+ def handle_attention_nsa(attn, forward_batch):
+ return AttnForwardMethod.MLA
+
+
+ def handle_attention_triton(attn, forward_batch):
+ if (
+ _is_extend_without_speculative(forward_batch)
+ and sum(forward_batch.extend_prefix_lens_cpu) == 0
+ ):
+ return AttnForwardMethod.MHA
+ else:
+ return _dispatch_mla_subtype(attn, forward_batch)
+
+
  class DeepseekV2MLP(nn.Module):
  def __init__(
  self,
@@ -246,7 +422,11 @@ class DeepseekV2MLP(nn.Module):
  if (self.tp_size == 1) and x.shape[0] == 0:
  return x

- if gemm_output_zero_allocator != None and x.shape[0] <= 256:
+ if (
+ gemm_output_zero_allocator is not None
+ and x.shape[0] <= 256
+ and self.gate_up_proj.weight.dtype == torch.uint8
+ ):
  y = gemm_output_zero_allocator.allocate(
  x.shape[0] * self.gate_up_proj.output_size_per_partition
  ).view(x.shape[0], self.gate_up_proj.output_size_per_partition)
@@ -264,6 +444,7 @@ class MoEGate(nn.Module):
  def __init__(
  self,
  config,
+ quant_config,
  prefix: str = "",
  is_nextn: bool = False,
  ):
@@ -273,8 +454,15 @@ class MoEGate(nn.Module):
  torch.empty((config.n_routed_experts, config.hidden_size))
  )
  if config.topk_method == "noaux_tc":
+ correction_bias_dtype = (
+ torch.bfloat16
+ if quant_config is not None
+ and quant_config.get_name() == "modelopt_fp4"
+ and should_use_flashinfer_trtllm_moe()
+ else torch.float32
+ )
  self.e_score_correction_bias = nn.Parameter(
- torch.empty((config.n_routed_experts), dtype=torch.float32)
+ torch.empty((config.n_routed_experts), dtype=correction_bias_dtype)
  )
  else:
  self.e_score_correction_bias = None
@@ -295,11 +483,13 @@
  _is_cuda
  and hidden_states.shape[0] <= 16
  and hidden_states.shape[1] == 7168
- and self.weight.shape[0] == 256
+ and (self.weight.shape[0] == 256 or self.weight.shape[0] == 384)
  and _device_sm >= 90
  ):
  # router gemm output float32
- logits = dsv3_router_gemm(hidden_states, self.weight)
+ logits = dsv3_router_gemm(
+ hidden_states, self.weight, out_dtype=torch.float32
+ )
  elif _use_aiter_gfx95 and hidden_states.shape[0] <= 256:
  logits = aiter_dsv3_router_gemm(
  hidden_states, self.weight, gemm_output_zero_allocator
@@ -347,7 +537,10 @@ class DeepseekV2MoE(nn.Module):
  )

  self.gate = MoEGate(
- config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
+ config=config,
+ quant_config=quant_config,
+ prefix=add_prefix("gate", prefix),
+ is_nextn=is_nextn,
  )

  self.experts = get_moe_impl_class(quant_config)(
@@ -372,9 +565,12 @@
  num_fused_shared_experts=self.num_fused_shared_experts,
  topk_group=config.topk_group,
  correction_bias=self.gate.e_score_correction_bias,
+ quant_config=quant_config,
  routed_scaling_factor=self.routed_scaling_factor,
- apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
- force_topk=quant_config is None,
+ apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk,
+ # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
+ # and requires the output format to be standard. We use quant_config to determine the output format.
+ output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
  )

  self.shared_experts_is_int8 = False
@@ -638,7 +834,8 @@ class DeepseekV2MoE(nn.Module):
  if hidden_states.shape[0] > 0:
  # router_logits: (num_tokens, n_experts)
  router_logits = self.gate(hidden_states)
- shared_output = self._forward_shared_experts(hidden_states)
+ if not SboFlags.fuse_shared_experts_inside_sbo():
+ shared_output = self._forward_shared_experts(hidden_states)
  topk_weights, topk_idx, _ = self.topk(
  hidden_states,
  router_logits,
@@ -652,26 +849,36 @@
  hidden_states.device
  )

- final_hidden_states = self.experts(
+ final_hidden_states, sbo_shared_output = single_batch_overlap.execute_sbo(
  hidden_states=hidden_states,
  topk_idx=topk_idx,
  topk_weights=topk_weights,
  forward_batch=forward_batch,
+ # SBO args
+ forward_shared_experts=lambda: self._forward_shared_experts(hidden_states),
+ experts=self.experts,
+ alt_stream=self.alt_stream,
  )
+ if sbo_shared_output is not None:
+ shared_output = sbo_shared_output

  if shared_output is not None:
  x = shared_output
- x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
+ if self.experts.should_fuse_routed_scaling_factor_in_topk:
+ x.add_(final_hidden_states)
+ else:
+ x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
  final_hidden_states = x
  else:
- final_hidden_states *= self.routed_scaling_factor
+ if not self.experts.should_fuse_routed_scaling_factor_in_topk:
+ final_hidden_states *= self.routed_scaling_factor

  return final_hidden_states

  def _forward_shared_experts(
  self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None
  ):
- if self.num_fused_shared_experts == 0:
+ if (hidden_states.shape[0] > 0) and (self.num_fused_shared_experts == 0):
  return self.shared_experts(
  hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
  )
@@ -724,6 +931,7 @@ class DeepseekV2MoE(nn.Module):
  if self.ep_size > 1:
  self.experts.deepep_dispatcher.dispatch_a(
  hidden_states=state.hidden_states_mlp_input,
+ input_global_scale=None,
  topk_idx=state.pop("topk_idx_local"),
  topk_weights=state.pop("topk_weights_local"),
  forward_batch=state.forward_batch,
@@ -824,6 +1032,10 @@ class DeepseekV2AttentionMLA(nn.Module):
  self.rope_theta = rope_theta
  self.max_position_embeddings = max_position_embeddings

+ # NOTE modification to rope_scaling must be done early enough, b/c e.g. Indexer needs it
+ if rope_scaling:
+ rope_scaling["rope_type"] = "deepseek_yarn"
+
  # For tensor parallel attention
  if self.q_lora_rank is not None:
  self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
@@ -861,6 +1073,26 @@
  prefix=add_prefix("kv_a_proj_with_mqa", prefix),
  )

+ self.use_nsa = is_deepseek_nsa(config)
+ if self.use_nsa:
+ self.indexer = Indexer(
+ hidden_size=hidden_size,
+ index_n_heads=get_nsa_index_n_heads(config),
+ index_head_dim=get_nsa_index_head_dim(config),
+ rope_head_dim=qk_rope_head_dim,
+ index_topk=get_nsa_index_topk(config),
+ q_lora_rank=q_lora_rank,
+ max_position_embeddings=max_position_embeddings,
+ rope_theta=rope_theta,
+ scale_fmt="ue8m0",
+ block_size=128,
+ rope_scaling=rope_scaling,
+ prefix=add_prefix("indexer", prefix),
+ quant_config=quant_config,
+ layer_id=layer_id,
+ alt_stream=alt_stream,
+ )
+
  self.kv_b_proj = ColumnParallelLinear(
  self.kv_lora_rank,
  self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
@@ -883,9 +1115,6 @@ class DeepseekV2AttentionMLA(nn.Module):
  )
  self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)

- if rope_scaling:
- rope_scaling["rope_type"] = "deepseek_yarn"
-
  self.rotary_emb = get_rope_wrapper(
  qk_rope_head_dim,
  rotary_dim=qk_rope_head_dim,
@@ -1009,102 +1238,34 @@ class DeepseekV2AttentionMLA(nn.Module):
  self.weight_block_size = (
  self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
  )
+ self.is_mla_preprocess_enabled = is_mla_preprocess_enabled()
+ if self.is_mla_preprocess_enabled:
+ assert (
+ quant_config is None or quant_config.get_name() == "w8a8_int8"
+ ), "MLA Preprocess only works with Unquant or W8A8Int8"
+ self.mla_preprocess = None

  def dispatch_attn_forward_method(
  self, forward_batch: ForwardBatch
  ) -> AttnForwardMethod:
- def _dispatch_mla_subtype():
- if _is_hip:
- if (
- self.rocm_fused_decode_mla
- and forward_batch.forward_mode.is_decode()
- ):
- return AttnForwardMethod.MLA_FUSED_ROPE
- else:
- return AttnForwardMethod.MLA
- else:
- if hasattr(self, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(
- self
- ):
- return AttnForwardMethod.MLA_FUSED_ROPE_CPU
- else:
- return AttnForwardMethod.MLA
-
  # Determine attention backend used by current forward batch
  if forward_batch.forward_mode.is_decode_or_idle():
  attention_backend = global_server_args_dict["decode_attention_backend"]
+ elif (
+ forward_batch.forward_mode.is_target_verify()
+ or forward_batch.forward_mode.is_draft_extend()
+ ):
+ # Use the specified backend for speculative operations (both verify and draft extend)
+ if global_server_args_dict["speculative_attention_mode"] == "decode":
+ attention_backend = global_server_args_dict["decode_attention_backend"]
+ else: # default to prefill
+ attention_backend = global_server_args_dict["prefill_attention_backend"]
  else:
  attention_backend = global_server_args_dict["prefill_attention_backend"]
  self.current_attention_backend = attention_backend

- if attention_backend == "ascend":
- if (
- forward_batch.forward_mode.is_extend()
- and not forward_batch.forward_mode.is_target_verify()
- and not forward_batch.forward_mode.is_draft_extend()
- ):
- return AttnForwardMethod.MHA
- else:
- return AttnForwardMethod.MLA
- elif (
- attention_backend == "flashinfer"
- or attention_backend == "fa3"
- or attention_backend == "flashmla"
- or attention_backend == "trtllm_mla"
- or attention_backend == "cutlass_mla"
- ):
- # Use MHA with chunked KV cache when prefilling on long sequences.
- sum_extend_prefix_lens = (
- sum(forward_batch.extend_prefix_lens_cpu)
- if forward_batch.extend_prefix_lens_cpu is not None
- else 0
- )
- # Flashinfer MLA: Do not absorb when enabling ragged prefill
- disable_ragged = (
- attention_backend == "flashinfer" or attention_backend == "flashmla"
- ) and self.flashinfer_mla_disable_ragged
- if (
- not disable_ragged
- and forward_batch.forward_mode.is_extend()
- and not forward_batch.forward_mode.is_target_verify()
- and not forward_batch.forward_mode.is_draft_extend()
- and (
- (
- sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
- and not self.disable_chunked_prefix_cache
- )
- or sum_extend_prefix_lens == 0
- )
- ):
- return AttnForwardMethod.MHA_CHUNKED_KV
- else:
- return _dispatch_mla_subtype()
- elif attention_backend == "aiter":
- if (
- forward_batch.forward_mode.is_extend()
- and not forward_batch.forward_mode.is_target_verify()
- and not forward_batch.forward_mode.is_draft_extend()
- ):
- if is_dp_attention_enabled():
- if sum(forward_batch.extend_prefix_lens_cpu) == 0:
- return AttnForwardMethod.MHA
- else:
- return AttnForwardMethod.MLA
- else:
- return AttnForwardMethod.MHA
- else:
- return AttnForwardMethod.MLA
- else:
- # Triton: Use normal computation for prefill and use weight absorption for extend/decode
- if (
- forward_batch.forward_mode.is_extend()
- and not forward_batch.forward_mode.is_target_verify()
- and not forward_batch.forward_mode.is_draft_extend()
- and sum(forward_batch.extend_prefix_lens_cpu) == 0
- ):
- return AttnForwardMethod.MHA
- else:
- return _dispatch_mla_subtype()
+ handler = AttentionBackendRegistry.get_handler(attention_backend)
+ return handler(self, forward_batch)

  def op_prepare(self, state):
  state.attn_intermediate_state = self.forward_prepare(
@@ -1159,7 +1320,6 @@ class DeepseekV2AttentionMLA(nn.Module):
  return hidden_states, None, forward_batch, None

  attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
-
  if attn_forward_method == AttnForwardMethod.MHA:
  inner_state = self.forward_normal_prepare(
  positions, hidden_states, forward_batch, zero_allocator
@@ -1169,7 +1329,30 @@ class DeepseekV2AttentionMLA(nn.Module):
  positions, hidden_states, forward_batch, zero_allocator
  )
  elif attn_forward_method == AttnForwardMethod.MLA:
- inner_state = self.forward_absorb_prepare(
+ if not self.is_mla_preprocess_enabled:
+ inner_state = self.forward_absorb_prepare(
+ positions, hidden_states, forward_batch, zero_allocator
+ )
+ else:
+ # TODO(iforgetmyname): to be separated as a standalone func
+ if self.mla_preprocess is None:
+ self.mla_preprocess = NPUFusedMLAPreprocess(
+ self.fused_qkv_a_proj_with_mqa,
+ self.q_a_layernorm,
+ self.kv_a_layernorm,
+ self.q_b_proj,
+ self.w_kc,
+ self.rotary_emb,
+ self.layer_id,
+ self.num_local_heads,
+ self.qk_nope_head_dim,
+ self.qk_rope_head_dim,
+ )
+ inner_state = self.mla_preprocess.forward(
+ positions, hidden_states, forward_batch, zero_allocator
+ )
+ elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
+ inner_state = self.forward_npu_sparse_prepare(
  positions, hidden_states, forward_batch, zero_allocator
  )
  elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
@@ -1197,6 +1380,8 @@ class DeepseekV2AttentionMLA(nn.Module):
  return self.forward_normal_chunked_kv_core(*inner_state)
  elif attn_forward_method == AttnForwardMethod.MLA:
  return self.forward_absorb_core(*inner_state)
+ elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
+ return self.forward_npu_sparse_core(*inner_state)
  elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
  return self.forward_absorb_fused_mla_rope_core(*inner_state)
  elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
@@ -1235,8 +1420,18 @@ class DeepseekV2AttentionMLA(nn.Module):
  q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
  q[..., self.qk_nope_head_dim :] = q_pe
  k = torch.empty_like(q)
- k[..., : self.qk_nope_head_dim] = k_nope
- k[..., self.qk_nope_head_dim :] = k_pe
+
+ # Temporary for DeepSeek V3/R1 only, but can generalize if needed
+ if (
+ _is_cuda
+ and (self.num_local_heads == 128)
+ and (self.qk_nope_head_dim == 128)
+ and (self.qk_rope_head_dim == 64)
+ ):
+ concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
+ else:
+ k[..., : self.qk_nope_head_dim] = k_nope
+ k[..., self.qk_nope_head_dim :] = k_pe

  if not _is_npu:
  latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
@@ -1266,7 +1461,10 @@ class DeepseekV2AttentionMLA(nn.Module):
  """
  return (
  self.current_attention_backend == "trtllm_mla"
- and forward_batch.forward_mode.is_decode_or_idle()
+ and (
+ forward_batch.forward_mode.is_decode_or_idle()
+ or forward_batch.forward_mode.is_target_verify()
+ )
  and forward_batch.attn_backend.data_type == torch.float8_e4m3fn
  )

@@ -1279,6 +1477,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  ):
  from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

+ q_lora = None
  if self.q_lora_rank is not None:
  if (
  (not isinstance(hidden_states, tuple))
@@ -1317,6 +1516,10 @@ class DeepseekV2AttentionMLA(nn.Module):
  q = self.q_a_layernorm(q)
  k_nope = self.kv_a_layernorm(k_nope)

+ # q_lora needed by indexer
+ if self.use_nsa:
+ q_lora = q
+
  k_nope = k_nope.unsqueeze(1)
  q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
  else:
@@ -1382,28 +1585,50 @@ class DeepseekV2AttentionMLA(nn.Module):
  q_nope_out = q_nope_out.transpose(0, 1)

  if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
- not _use_aiter or not _is_gfx95_supported
+ not _use_aiter or not _is_gfx95_supported or self.use_nsa
  ):
  q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

- return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
+ topk_indices = None
+ if q_lora is not None:
+ topk_indices = self.indexer(
+ x=hidden_states,
+ q_lora=q_lora,
+ positions=positions,
+ forward_batch=forward_batch,
+ layer_id=self.layer_id,
+ )
+
+ return (
+ q_pe,
+ k_pe,
+ q_nope_out,
+ k_nope,
+ forward_batch,
+ zero_allocator,
+ positions,
+ topk_indices,
+ )

  def forward_absorb_core(
- self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
+ self,
+ q_pe,
+ k_pe,
+ q_nope_out,
+ k_nope,
+ forward_batch,
+ zero_allocator,
+ positions,
+ topk_indices,
  ):
- if (
- self.current_attention_backend == "fa3"
- or self.current_attention_backend == "flashinfer"
- or self.current_attention_backend == "cutlass_mla"
- or self.current_attention_backend == "trtllm_mla"
- or self.current_attention_backend == "ascend"
- ):
+ if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
  extra_args = {}
  if self._fuse_rope_for_trtllm_mla(forward_batch):
  extra_args = {
  "cos_sin_cache": self.rotary_emb.cos_sin_cache,
  "is_neox": self.rotary_emb.is_neox_style,
  }
+
  attn_output = self.attn_mqa(
  q_nope_out,
  k_nope,
@@ -1412,6 +1637,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  q_rope=q_pe,
  k_rope=k_pe,
  **extra_args,
+ **(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
  )
  else:
  if _use_aiter_gfx95:
@@ -1431,7 +1657,13 @@ class DeepseekV2AttentionMLA(nn.Module):
  q = torch.cat([q_nope_out, q_pe], dim=-1)
  k = torch.cat([k_nope, k_pe], dim=-1)

- attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
+ attn_output = self.attn_mqa(
+ q,
+ k,
+ k_nope,
+ forward_batch,
+ **(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
+ )
  attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

  if self.use_deep_gemm_bmm:
@@ -1513,6 +1745,221 @@ class DeepseekV2AttentionMLA(nn.Module):

  return output

+ def forward_npu_sparse_prepare(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ forward_batch: ForwardBatch,
+ zero_allocator: BumpAllocator,
+ ):
+ """
+ Reuse `self.q_lora_rank is not None` branch from forward_absorb_prepare
+ """
+ if self.is_mla_preprocess_enabled and forward_batch.forward_mode.is_decode():
+ if self.mla_preprocess is None:
+ self.mla_preprocess = NPUFusedMLAPreprocess(
+ self.fused_qkv_a_proj_with_mqa,
+ self.q_a_layernorm,
+ self.kv_a_layernorm,
+ self.q_b_proj,
+ self.w_kc,
+ self.rotary_emb,
+ self.layer_id,
+ self.num_local_heads,
+ self.qk_nope_head_dim,
+ self.qk_rope_head_dim,
+ )
+ (
+ q_pe,
+ k_pe,
+ q_nope_out,
+ k_nope,
+ forward_batch,
+ zero_allocator,
+ positions,
+ ) = self.mla_preprocess.forward(
+ positions, hidden_states, forward_batch, zero_allocator
+ )
+
+ fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
+ q, _ = fused_qkv_a_proj_out.split(
+ [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+ )
+ q_lora = self.q_a_layernorm(q)
+ else:
+ from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
+ if (
+ (not isinstance(hidden_states, tuple))
+ and hidden_states.shape[0] <= 16
+ and self.use_min_latency_fused_a_gemm
+ ):
+ fused_qkv_a_proj_out = dsv3_fused_a_gemm(
+ hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
+ )
+ else:
+ fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
+ q, latent_cache = fused_qkv_a_proj_out.split(
+ [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+ )
+ k_nope = latent_cache[..., : self.kv_lora_rank]
+
+ # overlap qk norm
+ if self.alt_stream is not None and get_is_capture_mode():
+ current_stream = torch.cuda.current_stream()
+ self.alt_stream.wait_stream(current_stream)
+ q = self.q_a_layernorm(q)
+ with torch.cuda.stream(self.alt_stream):
+ k_nope = self.kv_a_layernorm(k_nope)
+ current_stream.wait_stream(self.alt_stream)
+ else:
+ if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
+ q, k_nope = fused_rms_mxfp4_quant(
+ q,
+ self.q_a_layernorm.weight,
+ self.q_a_layernorm.variance_epsilon,
+ k_nope,
+ self.kv_a_layernorm.weight,
+ self.kv_a_layernorm.variance_epsilon,
+ )
+ else:
+ q = self.q_a_layernorm(q)
+ k_nope = self.kv_a_layernorm(k_nope)
+
+ q_lora = q.clone() # required for topk_indices
+ k_nope = k_nope.unsqueeze(1)
+ q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+
+ q_nope, q_pe = q.split(
+ [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+ )
+ k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
+
+ if self.use_deep_gemm_bmm:
+ q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
+ per_token_group_quant_mla_deep_gemm_masked_fp8(
+ q_nope.transpose(0, 1)
+ )
+ )
+ q_nope_out = q_nope.new_empty(
+ (self.num_local_heads, aligned_m, self.kv_lora_rank)
+ )
+ deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+ (q_nope_val, q_nope_scale),
+ (self.w_kc, self.w_scale_k),
+ q_nope_out,
+ masked_m,
+ expected_m,
+ )
+ q_nope_out = q_nope_out[:, :expected_m, :]
+ elif _is_hip:
+ # TODO(haishaw): add bmm_fp8 to ROCm
+ if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
+ x = q_nope.transpose(0, 1)
+ q_nope_out = torch.empty(
+ x.shape[0],
+ x.shape[1],
+ self.w_kc.shape[2],
+ device=x.device,
+ dtype=torch.bfloat16,
+ )
+ batched_gemm_afp4wfp4_pre_quant(
+ x,
+ self.w_kc.transpose(-2, -1),
+ self.w_scale_k.transpose(-2, -1),
+ torch.bfloat16,
+ q_nope_out,
+ )
+ else:
+ q_nope_out = torch.bmm(
+ q_nope.to(torch.bfloat16).transpose(0, 1),
+ self.w_kc.to(torch.bfloat16) * self.w_scale,
+ )
+ elif self.w_kc.dtype == torch.float8_e4m3fn:
+ q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+ q_nope.transpose(0, 1),
+ zero_allocator.allocate(1),
+ )
+ q_nope_out = bmm_fp8(
+ q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
+ )
+ else:
+ q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
+
+ q_nope_out = q_nope_out.transpose(0, 1)
+
+ if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
+ not _use_aiter or not _is_gfx95_supported
+ ):
+ q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+
+ # TODO: multi-stream indexer
+ topk_indices = self.indexer(
+ hidden_states, q_lora, positions, forward_batch, self.layer_id
+ )
+
+ return (
+ q_pe,
+ k_pe,
+ q_nope_out,
+ k_nope,
+ topk_indices,
+ forward_batch,
+ zero_allocator,
+ positions,
+ )
+
+ def forward_npu_sparse_core(
+ self,
+ q_pe,
+ k_pe,
+ q_nope_out,
+ k_nope,
+ topk_indices,
+ forward_batch,
+ zero_allocator,
+ positions,
+ ):
+ attn_output = self.attn_mqa(
+ q_nope_out.contiguous(),
+ k_nope.contiguous(),
+ k_nope.contiguous(),
+ forward_batch,
+ save_kv_cache=True, # False if forward_batch.forward_mode.is_extend() else True,
+ q_rope=q_pe.contiguous(),
+ k_rope=k_pe.contiguous(),
+ topk_indices=topk_indices,
+ )
+ attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+
+ attn_bmm_output = torch.empty(
+ (attn_output.shape[0], self.num_local_heads, self.v_head_dim),
+ dtype=attn_output.dtype,
+ device=attn_output.device,
+ )
+
+ if not forward_batch.forward_mode.is_decode():
+ attn_output = attn_output.transpose(0, 1)
+ torch.bmm(
+ attn_output,
+ self.w_vc,
+ out=attn_bmm_output.view(
+ -1, self.num_local_heads, self.v_head_dim
+ ).transpose(0, 1),
+ )
+ else:
+ attn_output = attn_output.contiguous()
+ torch.ops.npu.batch_matmul_transpose(
+ attn_output, self.w_vc, attn_bmm_output
+ )
+
+ attn_bmm_output = attn_bmm_output.reshape(
+ -1, self.num_local_heads * self.v_head_dim
+ )
+
+ output, _ = self.o_proj(attn_bmm_output)
+ return output
+

  def forward_absorb_fused_mla_rope_prepare(
  self,
@@ -1838,6 +2285,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  tmp_lse = torch.empty_like(accum_lse)
  merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
  accum_output, accum_lse = tmp_output, tmp_lse
+ del kv, k, v, output, lse, tmp_output, tmp_lse

  return accum_output

@@ -1994,11 +2442,13 @@ class DeepseekV2DecoderLayer(nn.Module):
  zero_allocator: BumpAllocator,
  gemm_output_zero_allocator: BumpAllocator = None,
  ) -> torch.Tensor:
-
  quant_format = (
  "mxfp4"
  if _is_gfx95_supported
- and self.self_attn.fused_qkv_a_proj_with_mqa.weight == torch.uint8
+ and getattr(self.self_attn, "fused_qkv_a_proj_with_mqa", None) is not None
+ and getattr(self.self_attn.fused_qkv_a_proj_with_mqa, "weight", None)
+ is not None
+ and self.self_attn.fused_qkv_a_proj_with_mqa.weight.dtype == torch.uint8
  else ""
  )

@@ -2170,8 +2620,15 @@ class DeepseekV2Model(nn.Module):
  [
  "w13_weight",
  "w2_weight",
- "w13_blockscale_swizzled",
- "w2_blockscale_swizzled",
+ # only for nvfp4
+ *(
+ [
+ "w13_blockscale_swizzled",
+ "w2_blockscale_swizzled",
+ ]
+ if hasattr(module, "w13_blockscale_swizzled")
+ else []
+ ),
  ]
  if isinstance(module, FusedMoE)
  else []
@@ -2553,7 +3010,11 @@ class DeepseekV2ForCausalLM(nn.Module):
  0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
  ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)

- if _use_aiter_gfx95 and self.quant_config.get_name() == "quark":
+ if (
+ _use_aiter_gfx95
+ and self.quant_config is not None
+ and self.quant_config.get_name() == "quark"
+ ):
  w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = (
  quark_post_load_weights(self_attn, w, "mxfp4")
  )
@@ -2937,8 +3398,24 @@ class DeepseekV2ForCausalLM(nn.Module):
  )


+ AttentionBackendRegistry.register("ascend", handle_attention_ascend)
+ AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
+ AttentionBackendRegistry.register("fa3", handle_attention_fa3)
+ AttentionBackendRegistry.register("flashmla", handle_attention_flashmla)
+ AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla)
+ AttentionBackendRegistry.register("fa4", handle_attention_fa4)
+ AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla)
+ AttentionBackendRegistry.register("aiter", handle_attention_aiter)
+ AttentionBackendRegistry.register("nsa", handle_attention_nsa)
+ AttentionBackendRegistry.register("triton", handle_attention_triton)
+
+
  class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
  pass


- EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]
+ class DeepseekV32ForCausalLM(DeepseekV2ForCausalLM):
+ pass
+
+
+ EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM, DeepseekV32ForCausalLM]