sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff compares the contents of two publicly available versions of the package as released to their public registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in that registry.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
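The file inventory above can be reproduced locally from the two wheels. The sketch below is not part of either package; it assumes the wheels have already been downloaded into the current directory under the file names shown, and uses only the standard library (a wheel is a zip archive, so zipfile can list its contents):

import zipfile

# Assumed local file names for the two released wheels.
OLD_WHEEL = "sglang-0.5.3rc0-py3-none-any.whl"
NEW_WHEEL = "sglang-0.5.3rc2-py3-none-any.whl"

def wheel_files(path):
    # namelist() returns every file stored inside the wheel archive.
    with zipfile.ZipFile(path) as wheel:
        return set(wheel.namelist())

old_files = wheel_files(OLD_WHEEL)
new_files = wheel_files(NEW_WHEEL)

print("added files:", len(new_files - old_files))
print("removed files:", len(old_files - new_files))
print("files present in both:", len(old_files & new_files))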
sglang/srt/models/deepseek_v2.py

@@ -15,6 +15,7 @@
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
 """Inference-only DeepseekV2 model."""
+from __future__ import annotations
 
 import concurrent.futures
 import logging
@@ -25,9 +26,16 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
 from torch import nn
-from tqdm import tqdm
 from transformers import PretrainedConfig
 
+from sglang.srt import single_batch_overlap
+from sglang.srt.configs.model_config import (
+    get_nsa_index_head_dim,
+    get_nsa_index_n_heads,
+    get_nsa_index_topk,
+    is_deepseek_nsa,
+)
+from sglang.srt.debug_utils.dumper import dumper
 from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_pp_group,
@@ -43,6 +51,11 @@ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.amx_utils import PackWeightMethod
+from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
+    NPUFusedMLAPreprocess,
+    is_mla_preprocess_enabled,
+)
+from sglang.srt.layers.attention.nsa.nsa_indexer import Indexer
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
     LayerScatterModes,
@@ -97,6 +110,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.single_batch_overlap import SboFlags
 from sglang.srt.two_batch_overlap import (
     MaybeTboDeepEPDispatcher,
     model_forward_maybe_tbo,
@@ -160,16 +174,18 @@ if _is_cuda:
 elif _is_cpu and _is_cpu_amx_available:
     pass
 elif _is_hip:
+    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
+        decode_attention_fwd_grouped_rope,
+    )
     from sglang.srt.layers.quantization.awq_triton import (
         awq_dequantize_triton as awq_dequantize,
     )
+elif _is_npu:
+    import custom_ops
+    import sgl_kernel_npu
+    import torch_npu
 else:
-    from vllm._custom_ops import awq_dequantize
-
-if _is_hip:
-    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
-        decode_attention_fwd_grouped_rope,
-    )
+    pass
 
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
@@ -177,6 +193,21 @@ _is_sm100_supported = is_cuda() and is_sm100_supported()
 
 logger = logging.getLogger(__name__)
 
+FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [
+    "fa3",
+    "nsa",
+    "flashinfer",
+    "cutlass_mla",
+    "trtllm_mla",
+    "ascend",
+]
+
+
+def add_forward_absorb_core_attention_backend(backend_name):
+    if backend_name not in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
+        FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.append(backend_name)
+        logger.info(f"Added {backend_name} to FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.")
+
 
 class AttnForwardMethod(IntEnum):
     # Use multi-head attention
@@ -185,6 +216,9 @@ class AttnForwardMethod(IntEnum):
     # Use absorbed multi-latent attention
     MLA = auto()
 
+    # Use Deepseek V3.2 sparse multi-latent attention
+    NPU_MLA_SPARSE = auto()
+
     # Use multi-head attention, but with KV cache chunked.
     # This method can avoid OOM when prefix lengths are long.
     MHA_CHUNKED_KV = auto()
@@ -196,6 +230,146 @@ class AttnForwardMethod(IntEnum):
     MLA_FUSED_ROPE_CPU = auto()
 
 
+def _dispatch_mla_subtype(attn, forward_batch):
+    if _is_hip:
+        if attn.rocm_fused_decode_mla and forward_batch.forward_mode.is_decode():
+            return AttnForwardMethod.MLA_FUSED_ROPE
+        else:
+            return AttnForwardMethod.MLA
+    else:
+        if hasattr(attn, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(attn):
+            return AttnForwardMethod.MLA_FUSED_ROPE_CPU
+        else:
+            return AttnForwardMethod.MLA
+
+
+class AttentionBackendRegistry:
+    _handlers = {}
+
+    @classmethod
+    def register(cls, backend_name, handler_func):
+        cls._handlers[backend_name] = handler_func
+
+    @classmethod
+    def get_handler(cls, backend_name):
+        return cls._handlers.get(backend_name, cls._handlers.get("triton"))
+
+
+def handle_attention_ascend(attn, forward_batch):
+    if (
+        forward_batch.forward_mode.is_extend()
+        and not forward_batch.forward_mode.is_target_verify()
+        and not forward_batch.forward_mode.is_draft_extend()
+    ):
+        if hasattr(attn, "indexer"):
+            return AttnForwardMethod.NPU_MLA_SPARSE
+        else:
+            return AttnForwardMethod.MHA
+    else:
+        if hasattr(attn, "indexer"):
+            return AttnForwardMethod.NPU_MLA_SPARSE
+        else:
+            return AttnForwardMethod.MLA
+
+
+def _get_sum_extend_prefix_lens(forward_batch):
+    return (
+        sum(forward_batch.extend_prefix_lens_cpu)
+        if forward_batch.extend_prefix_lens_cpu is not None
+        else 0
+    )
+
+
+def _is_extend_without_speculative(forward_batch):
+    return (
+        forward_batch.forward_mode.is_extend()
+        and not forward_batch.forward_mode.is_target_verify()
+        and not forward_batch.forward_mode.is_draft_extend()
+    )
+
+
+def _handle_attention_backend(
+    attn: DeepseekV2AttentionMLA, forward_batch, backend_name
+):
+    sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
+    disable_ragged = (
+        backend_name in ["flashinfer", "flashmla"]
+    ) and attn.flashinfer_mla_disable_ragged
+
+    if (
+        not disable_ragged
+        and _is_extend_without_speculative(forward_batch)
+        and (
+            (
+                sum_extend_prefix_lens >= attn.chunked_prefix_cache_threshold
+                and not attn.disable_chunked_prefix_cache
+            )
+            or sum_extend_prefix_lens == 0
+        )
+    ):
+        return AttnForwardMethod.MHA_CHUNKED_KV
+    else:
+        return _dispatch_mla_subtype(attn, forward_batch)
+
+
+def handle_attention_flashinfer(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "flashinfer")
+
+
+def handle_attention_fa3(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "fa3")
+
+
+def handle_attention_flashmla(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "flashmla")
+
+
+def handle_attention_cutlass_mla(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "cutlass_mla")
+
+
+def handle_attention_fa4(attn, forward_batch):
+    # TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
+    return AttnForwardMethod.MHA_CHUNKED_KV
+
+
+def handle_attention_trtllm_mla(attn, forward_batch):
+    sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
+    if _is_extend_without_speculative(forward_batch) and (
+        not attn.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
+    ):
+        return AttnForwardMethod.MHA_CHUNKED_KV
+    else:
+        return _dispatch_mla_subtype(attn, forward_batch)
+
+
+def handle_attention_aiter(attn, forward_batch):
+    if _is_extend_without_speculative(forward_batch):
+        if is_dp_attention_enabled():
+            if sum(forward_batch.extend_prefix_lens_cpu) == 0:
+                return AttnForwardMethod.MHA
+            else:
+                return AttnForwardMethod.MLA
+        else:
+            return AttnForwardMethod.MHA
+    else:
+        return AttnForwardMethod.MLA
+
+
+def handle_attention_nsa(attn, forward_batch):
+    return AttnForwardMethod.MLA
+
+
+def handle_attention_triton(attn, forward_batch):
+    if (
+        _is_extend_without_speculative(forward_batch)
+        and sum(forward_batch.extend_prefix_lens_cpu) == 0
+    ):
+        return AttnForwardMethod.MHA
+    else:
+        return _dispatch_mla_subtype(attn, forward_batch)
+
+
 class DeepseekV2MLP(nn.Module):
     def __init__(
         self,
@@ -309,7 +483,7 @@ class MoEGate(nn.Module):
             _is_cuda
             and hidden_states.shape[0] <= 16
             and hidden_states.shape[1] == 7168
-            and self.weight.shape[0] == 256
+            and (self.weight.shape[0] == 256 or self.weight.shape[0] == 384)
             and _device_sm >= 90
         ):
             # router gemm output float32
@@ -393,7 +567,7 @@ class DeepseekV2MoE(nn.Module):
             correction_bias=self.gate.e_score_correction_bias,
             quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
-            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
+            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk,
             # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
             # and requires the output format to be standard. We use quant_config to determine the output format.
             output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
@@ -660,7 +834,8 @@ class DeepseekV2MoE(nn.Module):
         if hidden_states.shape[0] > 0:
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            shared_output = self._forward_shared_experts(hidden_states)
+            if not SboFlags.fuse_shared_experts_inside_sbo():
+                shared_output = self._forward_shared_experts(hidden_states)
             topk_weights, topk_idx, _ = self.topk(
                 hidden_states,
                 router_logits,
@@ -674,22 +849,28 @@ class DeepseekV2MoE(nn.Module):
                 hidden_states.device
             )
 
-        final_hidden_states = self.experts(
+        final_hidden_states, sbo_shared_output = single_batch_overlap.execute_sbo(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
             forward_batch=forward_batch,
+            # SBO args
+            forward_shared_experts=lambda: self._forward_shared_experts(hidden_states),
+            experts=self.experts,
+            alt_stream=self.alt_stream,
         )
+        if sbo_shared_output is not None:
+            shared_output = sbo_shared_output
 
         if shared_output is not None:
             x = shared_output
-            if self.experts.should_fuse_routed_scaling_factor_in_topk():
+            if self.experts.should_fuse_routed_scaling_factor_in_topk:
                 x.add_(final_hidden_states)
             else:
                 x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
             final_hidden_states = x
         else:
-            if not self.experts.should_fuse_routed_scaling_factor_in_topk():
+            if not self.experts.should_fuse_routed_scaling_factor_in_topk:
                 final_hidden_states *= self.routed_scaling_factor
 
         return final_hidden_states
@@ -697,7 +878,7 @@ class DeepseekV2MoE(nn.Module):
     def _forward_shared_experts(
         self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None
    ):
-        if self.num_fused_shared_experts == 0:
+        if (hidden_states.shape[0] > 0) and (self.num_fused_shared_experts == 0):
             return self.shared_experts(
                 hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
             )
@@ -750,6 +931,7 @@ class DeepseekV2MoE(nn.Module):
         if self.ep_size > 1:
             self.experts.deepep_dispatcher.dispatch_a(
                 hidden_states=state.hidden_states_mlp_input,
+                input_global_scale=None,
                 topk_idx=state.pop("topk_idx_local"),
                 topk_weights=state.pop("topk_weights_local"),
                 forward_batch=state.forward_batch,
@@ -850,6 +1032,10 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
 
+        # NOTE modification to rope_scaling must be done early enough, b/c e.g. Indexer needs it
+        if rope_scaling:
+            rope_scaling["rope_type"] = "deepseek_yarn"
+
         # For tensor parallel attention
         if self.q_lora_rank is not None:
             self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
@@ -887,6 +1073,26 @@ class DeepseekV2AttentionMLA(nn.Module):
                 prefix=add_prefix("kv_a_proj_with_mqa", prefix),
             )
 
+        self.use_nsa = is_deepseek_nsa(config)
+        if self.use_nsa:
+            self.indexer = Indexer(
+                hidden_size=hidden_size,
+                index_n_heads=get_nsa_index_n_heads(config),
+                index_head_dim=get_nsa_index_head_dim(config),
+                rope_head_dim=qk_rope_head_dim,
+                index_topk=get_nsa_index_topk(config),
+                q_lora_rank=q_lora_rank,
+                max_position_embeddings=max_position_embeddings,
+                rope_theta=rope_theta,
+                scale_fmt="ue8m0",
+                block_size=128,
+                rope_scaling=rope_scaling,
+                prefix=add_prefix("indexer", prefix),
+                quant_config=quant_config,
+                layer_id=layer_id,
+                alt_stream=alt_stream,
+            )
+
         self.kv_b_proj = ColumnParallelLinear(
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
@@ -909,9 +1115,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
 
-        if rope_scaling:
-            rope_scaling["rope_type"] = "deepseek_yarn"
-
         self.rotary_emb = get_rope_wrapper(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
@@ -1035,27 +1238,16 @@ class DeepseekV2AttentionMLA(nn.Module):
             self.weight_block_size = (
                 self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
             )
+        self.is_mla_preprocess_enabled = is_mla_preprocess_enabled()
+        if self.is_mla_preprocess_enabled:
+            assert (
+                quant_config is None or quant_config.get_name() == "w8a8_int8"
+            ), "MLA Preprocess only works with Unquant or W8A8Int8"
+            self.mla_preprocess = None
 
     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
     ) -> AttnForwardMethod:
-        def _dispatch_mla_subtype():
-            if _is_hip:
-                if (
-                    self.rocm_fused_decode_mla
-                    and forward_batch.forward_mode.is_decode()
-                ):
-                    return AttnForwardMethod.MLA_FUSED_ROPE
-                else:
-                    return AttnForwardMethod.MLA
-            else:
-                if hasattr(self, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(
-                    self
-                ):
-                    return AttnForwardMethod.MLA_FUSED_ROPE_CPU
-                else:
-                    return AttnForwardMethod.MLA
-
         # Determine attention backend used by current forward batch
         if forward_batch.forward_mode.is_decode_or_idle():
             attention_backend = global_server_args_dict["decode_attention_backend"]
@@ -1072,109 +1264,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             attention_backend = global_server_args_dict["prefill_attention_backend"]
         self.current_attention_backend = attention_backend
 
-        if attention_backend == "ascend":
-            if (
-                forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-            ):
-                return AttnForwardMethod.MHA
-            else:
-                return AttnForwardMethod.MLA
-        elif (
-            attention_backend == "flashinfer"
-            or attention_backend == "fa3"
-            or attention_backend == "flashmla"
-            or attention_backend == "cutlass_mla"
-        ):
-            # Use MHA with chunked KV cache when prefilling on long sequences.
-            sum_extend_prefix_lens = (
-                sum(forward_batch.extend_prefix_lens_cpu)
-                if forward_batch.extend_prefix_lens_cpu is not None
-                else 0
-            )
-            # Flashinfer MLA: Do not absorb when enabling ragged prefill
-            disable_ragged = (
-                attention_backend == "flashinfer" or attention_backend == "flashmla"
-            ) and self.flashinfer_mla_disable_ragged
-
-            original_mode = getattr(forward_batch, "_original_forward_mode", None)
-            if (
-                not disable_ragged
-                and forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-                and (
-                    (
-                        sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
-                        and not self.disable_chunked_prefix_cache
-                    )
-                    or sum_extend_prefix_lens == 0
-                )
-                # TODO(shuw@nvidia.com) Flashinfer cutlass and trtllm_mla backend have accuracy issue on blackwell for
-                # dp case. Redirect to mla kernel as a workaround.
-                # Tracked by https://github.com/sgl-project/sglang/issues/9806.
-                and not (
-                    original_mode is not None
-                    and original_mode.is_decode()
-                    and is_sm100_supported()
-                    and self.current_attention_backend in ("cutlass_mla", "flashinfer")
-                )
-            ):
-                return AttnForwardMethod.MHA_CHUNKED_KV
-            else:
-                return _dispatch_mla_subtype()
-        elif attention_backend == "trtllm_mla":
-            original_mode = getattr(forward_batch, "_original_forward_mode", None)
-            if (
-                original_mode is not None
-                and original_mode.is_decode()
-                and is_sm100_supported()
-            ):
-                return _dispatch_mla_subtype()
-
-            sum_extend_prefix_lens = (
-                sum(forward_batch.extend_prefix_lens_cpu)
-                if forward_batch.extend_prefix_lens_cpu is not None
-                else 0
-            )
-            if (
-                forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-                and (
-                    not self.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
-                )
-            ):
-                return AttnForwardMethod.MHA_CHUNKED_KV
-            else:
-                return _dispatch_mla_subtype()
-        elif attention_backend == "aiter":
-            if (
-                forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-            ):
-                if is_dp_attention_enabled():
-                    if sum(forward_batch.extend_prefix_lens_cpu) == 0:
-                        return AttnForwardMethod.MHA
-                    else:
-                        return AttnForwardMethod.MLA
-                else:
-                    return AttnForwardMethod.MHA
-            else:
-                return AttnForwardMethod.MLA
-        else:
-            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
-            if (
-                forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-                and sum(forward_batch.extend_prefix_lens_cpu) == 0
-            ):
-                return AttnForwardMethod.MHA
-            else:
-                return _dispatch_mla_subtype()
+        handler = AttentionBackendRegistry.get_handler(attention_backend)
+        return handler(self, forward_batch)
 
     def op_prepare(self, state):
         state.attn_intermediate_state = self.forward_prepare(
@@ -1229,7 +1320,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             return hidden_states, None, forward_batch, None
 
         attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
-
         if attn_forward_method == AttnForwardMethod.MHA:
             inner_state = self.forward_normal_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
@@ -1239,7 +1329,30 @@ class DeepseekV2AttentionMLA(nn.Module):
                 positions, hidden_states, forward_batch, zero_allocator
             )
         elif attn_forward_method == AttnForwardMethod.MLA:
-            inner_state = self.forward_absorb_prepare(
+            if not self.is_mla_preprocess_enabled:
+                inner_state = self.forward_absorb_prepare(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
+            else:
+                # TODO(iforgetmyname): to be separated as a standalone func
+                if self.mla_preprocess is None:
+                    self.mla_preprocess = NPUFusedMLAPreprocess(
+                        self.fused_qkv_a_proj_with_mqa,
+                        self.q_a_layernorm,
+                        self.kv_a_layernorm,
+                        self.q_b_proj,
+                        self.w_kc,
+                        self.rotary_emb,
+                        self.layer_id,
+                        self.num_local_heads,
+                        self.qk_nope_head_dim,
+                        self.qk_rope_head_dim,
+                    )
+                inner_state = self.mla_preprocess.forward(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
+        elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
+            inner_state = self.forward_npu_sparse_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
             )
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
@@ -1267,6 +1380,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             return self.forward_normal_chunked_kv_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MLA:
             return self.forward_absorb_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
+            return self.forward_npu_sparse_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
             return self.forward_absorb_fused_mla_rope_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
@@ -1346,7 +1461,10 @@ class DeepseekV2AttentionMLA(nn.Module):
         """
         return (
             self.current_attention_backend == "trtllm_mla"
-            and forward_batch.forward_mode.is_decode_or_idle()
+            and (
+                forward_batch.forward_mode.is_decode_or_idle()
+                or forward_batch.forward_mode.is_target_verify()
+            )
             and forward_batch.attn_backend.data_type == torch.float8_e4m3fn
         )
 
@@ -1359,6 +1477,7 @@ class DeepseekV2AttentionMLA(nn.Module):
     ):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 
+        q_lora = None
         if self.q_lora_rank is not None:
             if (
                 (not isinstance(hidden_states, tuple))
@@ -1397,6 +1516,10 @@ class DeepseekV2AttentionMLA(nn.Module):
                 q = self.q_a_layernorm(q)
                 k_nope = self.kv_a_layernorm(k_nope)
 
+            # q_lora needed by indexer
+            if self.use_nsa:
+                q_lora = q
+
             k_nope = k_nope.unsqueeze(1)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
@@ -1462,28 +1585,50 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_nope_out = q_nope_out.transpose(0, 1)
 
         if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
-            not _use_aiter or not _is_gfx95_supported
+            not _use_aiter or not _is_gfx95_supported or self.use_nsa
        ):
             q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
-        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
+        topk_indices = None
+        if q_lora is not None:
+            topk_indices = self.indexer(
+                x=hidden_states,
+                q_lora=q_lora,
+                positions=positions,
+                forward_batch=forward_batch,
+                layer_id=self.layer_id,
+            )
+
+        return (
+            q_pe,
+            k_pe,
+            q_nope_out,
+            k_nope,
+            forward_batch,
+            zero_allocator,
+            positions,
+            topk_indices,
+        )
 
     def forward_absorb_core(
-        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
+        self,
+        q_pe,
+        k_pe,
+        q_nope_out,
+        k_nope,
+        forward_batch,
+        zero_allocator,
+        positions,
+        topk_indices,
     ):
-        if (
-            self.current_attention_backend == "fa3"
-            or self.current_attention_backend == "flashinfer"
-            or self.current_attention_backend == "cutlass_mla"
-            or self.current_attention_backend == "trtllm_mla"
-            or self.current_attention_backend == "ascend"
-        ):
+        if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
             extra_args = {}
             if self._fuse_rope_for_trtllm_mla(forward_batch):
                 extra_args = {
                     "cos_sin_cache": self.rotary_emb.cos_sin_cache,
                     "is_neox": self.rotary_emb.is_neox_style,
                 }
+
             attn_output = self.attn_mqa(
                 q_nope_out,
                 k_nope,
@@ -1492,6 +1637,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 q_rope=q_pe,
                 k_rope=k_pe,
                 **extra_args,
+                **(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
             )
         else:
             if _use_aiter_gfx95:
@@ -1511,7 +1657,13 @@ class DeepseekV2AttentionMLA(nn.Module):
             q = torch.cat([q_nope_out, q_pe], dim=-1)
             k = torch.cat([k_nope, k_pe], dim=-1)
 
-            attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
+            attn_output = self.attn_mqa(
+                q,
+                k,
+                k_nope,
+                forward_batch,
+                **(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
+            )
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
 
         if self.use_deep_gemm_bmm:
@@ -1593,6 +1745,221 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         return output
 
+    def forward_npu_sparse_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ):
+        """
+        Reuse `self.q_lora_rank is not None` branch from forward_absorb_prepare
+        """
+        if self.is_mla_preprocess_enabled and forward_batch.forward_mode.is_decode():
+            if self.mla_preprocess is None:
+                self.mla_preprocess = NPUFusedMLAPreprocess(
+                    self.fused_qkv_a_proj_with_mqa,
+                    self.q_a_layernorm,
+                    self.kv_a_layernorm,
+                    self.q_b_proj,
+                    self.w_kc,
+                    self.rotary_emb,
+                    self.layer_id,
+                    self.num_local_heads,
+                    self.qk_nope_head_dim,
+                    self.qk_rope_head_dim,
+                )
+            (
+                q_pe,
+                k_pe,
+                q_nope_out,
+                k_nope,
+                forward_batch,
+                zero_allocator,
+                positions,
+            ) = self.mla_preprocess.forward(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
+
+            fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
+            q, _ = fused_qkv_a_proj_out.split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
+            q_lora = self.q_a_layernorm(q)
+        else:
+            from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
+            if (
+                (not isinstance(hidden_states, tuple))
+                and hidden_states.shape[0] <= 16
+                and self.use_min_latency_fused_a_gemm
+            ):
+                fused_qkv_a_proj_out = dsv3_fused_a_gemm(
+                    hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
+                )
+            else:
+                fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
+            q, latent_cache = fused_qkv_a_proj_out.split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
+            k_nope = latent_cache[..., : self.kv_lora_rank]
+
+            # overlap qk norm
+            if self.alt_stream is not None and get_is_capture_mode():
+                current_stream = torch.cuda.current_stream()
+                self.alt_stream.wait_stream(current_stream)
+                q = self.q_a_layernorm(q)
+                with torch.cuda.stream(self.alt_stream):
+                    k_nope = self.kv_a_layernorm(k_nope)
+                current_stream.wait_stream(self.alt_stream)
+            else:
+                if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
+                    q, k_nope = fused_rms_mxfp4_quant(
+                        q,
+                        self.q_a_layernorm.weight,
+                        self.q_a_layernorm.variance_epsilon,
+                        k_nope,
+                        self.kv_a_layernorm.weight,
+                        self.kv_a_layernorm.variance_epsilon,
+                    )
+                else:
+                    q = self.q_a_layernorm(q)
+                    k_nope = self.kv_a_layernorm(k_nope)
+
+            q_lora = q.clone()  # required for topk_indices
+            k_nope = k_nope.unsqueeze(1)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+
+            q_nope, q_pe = q.split(
+                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+            )
+            k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
+
+            if self.use_deep_gemm_bmm:
+                q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
+                    per_token_group_quant_mla_deep_gemm_masked_fp8(
+                        q_nope.transpose(0, 1)
+                    )
+                )
+                q_nope_out = q_nope.new_empty(
+                    (self.num_local_heads, aligned_m, self.kv_lora_rank)
+                )
+                deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+                    (q_nope_val, q_nope_scale),
+                    (self.w_kc, self.w_scale_k),
+                    q_nope_out,
+                    masked_m,
+                    expected_m,
+                )
+                q_nope_out = q_nope_out[:, :expected_m, :]
+            elif _is_hip:
+                # TODO(haishaw): add bmm_fp8 to ROCm
+                if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
+                    x = q_nope.transpose(0, 1)
+                    q_nope_out = torch.empty(
+                        x.shape[0],
+                        x.shape[1],
+                        self.w_kc.shape[2],
+                        device=x.device,
+                        dtype=torch.bfloat16,
+                    )
+                    batched_gemm_afp4wfp4_pre_quant(
+                        x,
+                        self.w_kc.transpose(-2, -1),
+                        self.w_scale_k.transpose(-2, -1),
+                        torch.bfloat16,
+                        q_nope_out,
+                    )
+                else:
+                    q_nope_out = torch.bmm(
+                        q_nope.to(torch.bfloat16).transpose(0, 1),
+                        self.w_kc.to(torch.bfloat16) * self.w_scale,
+                    )
+            elif self.w_kc.dtype == torch.float8_e4m3fn:
+                q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+                    q_nope.transpose(0, 1),
+                    zero_allocator.allocate(1),
+                )
+                q_nope_out = bmm_fp8(
+                    q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
+                )
+            else:
+                q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
+
+            q_nope_out = q_nope_out.transpose(0, 1)
+
+            if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
+                not _use_aiter or not _is_gfx95_supported
+            ):
+                q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+
+        # TODO: multi-stream indexer
+        topk_indices = self.indexer(
+            hidden_states, q_lora, positions, forward_batch, self.layer_id
+        )
+
+        return (
+            q_pe,
+            k_pe,
+            q_nope_out,
+            k_nope,
+            topk_indices,
+            forward_batch,
+            zero_allocator,
+            positions,
+        )
+
+    def forward_npu_sparse_core(
+        self,
+        q_pe,
+        k_pe,
+        q_nope_out,
+        k_nope,
+        topk_indices,
+        forward_batch,
+        zero_allocator,
+        positions,
+    ):
+        attn_output = self.attn_mqa(
+            q_nope_out.contiguous(),
+            k_nope.contiguous(),
+            k_nope.contiguous(),
+            forward_batch,
+            save_kv_cache=True,  # False if forward_batch.forward_mode.is_extend() else True,
+            q_rope=q_pe.contiguous(),
+            k_rope=k_pe.contiguous(),
+            topk_indices=topk_indices,
+        )
+        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+
+        attn_bmm_output = torch.empty(
+            (attn_output.shape[0], self.num_local_heads, self.v_head_dim),
+            dtype=attn_output.dtype,
+            device=attn_output.device,
+        )
+
+        if not forward_batch.forward_mode.is_decode():
+            attn_output = attn_output.transpose(0, 1)
+            torch.bmm(
+                attn_output,
+                self.w_vc,
+                out=attn_bmm_output.view(
+                    -1, self.num_local_heads, self.v_head_dim
+                ).transpose(0, 1),
+            )
+        else:
+            attn_output = attn_output.contiguous()
+            torch.ops.npu.batch_matmul_transpose(
+                attn_output, self.w_vc, attn_bmm_output
+            )
+
+        attn_bmm_output = attn_bmm_output.reshape(
+            -1, self.num_local_heads * self.v_head_dim
+        )
+
+        output, _ = self.o_proj(attn_bmm_output)
+        return output
+
     def forward_absorb_fused_mla_rope_prepare(
         self,
         positions: torch.Tensor,
@@ -1918,6 +2285,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             tmp_lse = torch.empty_like(accum_lse)
             merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
             accum_output, accum_lse = tmp_output, tmp_lse
+            del kv, k, v, output, lse, tmp_output, tmp_lse
 
         return accum_output
 
@@ -2074,7 +2442,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         zero_allocator: BumpAllocator,
         gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
-
         quant_format = (
             "mxfp4"
             if _is_gfx95_supported
@@ -3031,8 +3398,24 @@ class DeepseekV2ForCausalLM(nn.Module):
         )
 
 
+AttentionBackendRegistry.register("ascend", handle_attention_ascend)
+AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
+AttentionBackendRegistry.register("fa3", handle_attention_fa3)
+AttentionBackendRegistry.register("flashmla", handle_attention_flashmla)
+AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla)
+AttentionBackendRegistry.register("fa4", handle_attention_fa4)
+AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla)
+AttentionBackendRegistry.register("aiter", handle_attention_aiter)
+AttentionBackendRegistry.register("nsa", handle_attention_nsa)
+AttentionBackendRegistry.register("triton", handle_attention_triton)
+
+
 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
     pass
 
 
-EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]
+class DeepseekV32ForCausalLM(DeepseekV2ForCausalLM):
+    pass
+
+
+EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM, DeepseekV32ForCausalLM]
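
The largest structural change in this diff is that DeepseekV2AttentionMLA.dispatch_attn_forward_method no longer walks a per-backend if/elif chain: each attention backend now registers a handler in AttentionBackendRegistry and dispatch reduces to a table lookup. The standalone sketch below only mirrors the shape of that pattern (the real registry, handlers, and AttnForwardMethod enum live in sglang/srt/models/deepseek_v2.py as shown above); the toy handler logic here is illustrative, not the actual backend policy:

from enum import IntEnum, auto


class AttnForwardMethod(IntEnum):
    MHA = auto()
    MLA = auto()
    MHA_CHUNKED_KV = auto()


class AttentionBackendRegistry:
    _handlers = {}

    @classmethod
    def register(cls, backend_name, handler_func):
        cls._handlers[backend_name] = handler_func

    @classmethod
    def get_handler(cls, backend_name):
        # Unknown backends fall back to the "triton" handler, as in the diff above.
        return cls._handlers.get(backend_name, cls._handlers.get("triton"))


def handle_attention_triton(attn, forward_batch):
    # Placeholder policy: the real handlers inspect forward_batch and attn state.
    return AttnForwardMethod.MLA


AttentionBackendRegistry.register("triton", handle_attention_triton)

# Dispatch is now a lookup plus a call; unregistered names use the fallback.
handler = AttentionBackendRegistry.get_handler("some_new_backend")
print(handler(None, None))  # AttnForwardMethod.MLA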