sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
@@ -15,6 +15,7 @@
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
 """Inference-only DeepseekV2 model."""
+from __future__ import annotations

 import concurrent.futures
 import logging
@@ -24,10 +25,16 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
+import tqdm
 from torch import nn
-from tqdm import tqdm
 from transformers import PretrainedConfig

+from sglang.srt.configs.model_config import (
+    get_nsa_index_head_dim,
+    get_nsa_index_n_heads,
+    get_nsa_index_topk,
+    is_deepseek_nsa,
+)
 from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_pp_group,
@@ -38,11 +45,18 @@ from sglang.srt.distributed import (
 from sglang.srt.distributed.device_communicators.pynccl_allocator import (
     use_symmetric_memory,
 )
+from sglang.srt.environ import envs
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
+from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.amx_utils import PackWeightMethod
+from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
+    NPUFusedMLAPreprocess,
+    is_mla_preprocess_enabled,
+)
+from sglang.srt.layers.attention.nsa.nsa_indexer import Indexer
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
     LayerScatterModes,
@@ -62,7 +76,6 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe import (
-    get_deepep_mode,
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
     should_use_flashinfer_trtllm_moe,
@@ -70,8 +83,12 @@ from sglang.srt.layers.moe import (
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
-from sglang.srt.layers.quantization import deep_gemm_wrapper
+from sglang.srt.layers.quantization import CompressedTensorsConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (
+    CompressedTensorsWNA16AMXEPMoEMethod,
+)
+from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     per_tensor_quant_mla_fp8,
@@ -82,7 +99,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
     normalize_e4m3fn_to_e4m3fnuz,
+    quant_weight_ue8m0,
     requant_weight_ue8m0_inplace,
+    transform_scale_ue8m0_inplace,
 )
 from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
@@ -94,13 +113,12 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.two_batch_overlap import (
-    MaybeTboDeepEPDispatcher,
-    model_forward_maybe_tbo,
-)
+from sglang.srt.server_args import get_global_server_args
+from sglang.srt.single_batch_overlap import SboFlags
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
 from sglang.srt.utils import (
     BumpAllocator,
     LazyValue,
@@ -117,6 +135,7 @@ from sglang.srt.utils import (
     is_hip,
     is_non_idle_and_non_empty,
     is_npu,
+    is_nvidia_cublas_cu12_version_ge_12_9,
     is_sm100_supported,
     log_info_on_rank0,
     make_layers,
@@ -160,24 +179,54 @@ if _is_cuda:
 elif _is_cpu and _is_cpu_amx_available:
     pass
 elif _is_hip:
+    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
+        decode_attention_fwd_grouped_rope,
+    )
     from sglang.srt.layers.quantization.awq_triton import (
         awq_dequantize_triton as awq_dequantize,
     )
-else:
-    from vllm._custom_ops import awq_dequantize
+elif _is_npu:
+    import custom_ops  # noqa: F401
+    import sgl_kernel_npu  # noqa: F401
+    import torch_npu  # noqa: F401

-if _is_hip:
-    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
-        decode_attention_fwd_grouped_rope,
+    from sglang.srt.layers.quantization.awq_triton import (
+        awq_dequantize_decomposition as awq_dequantize,
     )
+else:
+    pass

 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
-
+_is_cublas_ge_129 = is_nvidia_cublas_cu12_version_ge_12_9()

 logger = logging.getLogger(__name__)


+def enable_nextn_moe_bf16_cast_to_fp8(quant_config):
+    return (
+        quant_config is not None
+        and quant_config.get_name() == "modelopt_fp4"
+        and get_moe_a2a_backend().is_deepep()
+    )
+
+
+FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [
+    "fa3",
+    "nsa",
+    "flashinfer",
+    "cutlass_mla",
+    "trtllm_mla",
+    "ascend",
+]
+
+
+def add_forward_absorb_core_attention_backend(backend_name):
+    if backend_name not in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
+        FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.append(backend_name)
+        logger.info(f"Added {backend_name} to FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.")
+
+
 class AttnForwardMethod(IntEnum):
     # Use multi-head attention
     MHA = auto()
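Note on the hunk above: FORWARD_ABSORB_CORE_ATTENTION_BACKENDS lists the attention backends routed through the absorbed-MLA forward path, and add_forward_absorb_core_attention_backend() appends extra names at runtime. A minimal usage sketch, assuming these names stay module-level in sglang.srt.models.deepseek_v2 as the diff shows; the backend name below is purely illustrative:

    # Illustrative only; "my_mla_backend" is a hypothetical backend name.
    from sglang.srt.models.deepseek_v2 import (
        FORWARD_ABSORB_CORE_ATTENTION_BACKENDS,
        add_forward_absorb_core_attention_backend,
    )

    add_forward_absorb_core_attention_backend("my_mla_backend")
    assert "my_mla_backend" in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS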
@@ -185,6 +234,9 @@ class AttnForwardMethod(IntEnum):
     # Use absorbed multi-latent attention
     MLA = auto()

+    # Use Deepseek V3.2 sparse multi-latent attention
+    NPU_MLA_SPARSE = auto()
+
     # Use multi-head attention, but with KV cache chunked.
     # This method can avoid OOM when prefix lengths are long.
     MHA_CHUNKED_KV = auto()
@@ -196,6 +248,146 @@ class AttnForwardMethod(IntEnum):
     MLA_FUSED_ROPE_CPU = auto()


+def _dispatch_mla_subtype(attn, forward_batch):
+    if _is_hip:
+        if attn.rocm_fused_decode_mla and forward_batch.forward_mode.is_decode():
+            return AttnForwardMethod.MLA_FUSED_ROPE
+        else:
+            return AttnForwardMethod.MLA
+    else:
+        if hasattr(attn, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(attn):
+            return AttnForwardMethod.MLA_FUSED_ROPE_CPU
+        else:
+            return AttnForwardMethod.MLA
+
+
+class AttentionBackendRegistry:
+    _handlers = {}
+
+    @classmethod
+    def register(cls, backend_name, handler_func):
+        cls._handlers[backend_name] = handler_func
+
+    @classmethod
+    def get_handler(cls, backend_name):
+        return cls._handlers.get(backend_name, cls._handlers.get("triton"))
+
+
+def handle_attention_ascend(attn, forward_batch):
+    if (
+        forward_batch.forward_mode.is_extend()
+        and not forward_batch.forward_mode.is_target_verify()
+        and not forward_batch.forward_mode.is_draft_extend()
+    ):
+        if hasattr(attn, "indexer"):
+            return AttnForwardMethod.NPU_MLA_SPARSE
+        else:
+            return AttnForwardMethod.MHA
+    else:
+        if hasattr(attn, "indexer"):
+            return AttnForwardMethod.NPU_MLA_SPARSE
+        else:
+            return AttnForwardMethod.MLA
+
+
+def _get_sum_extend_prefix_lens(forward_batch):
+    return (
+        sum(forward_batch.extend_prefix_lens_cpu)
+        if forward_batch.extend_prefix_lens_cpu is not None
+        else 0
+    )
+
+
+def _is_extend_without_speculative(forward_batch):
+    return (
+        forward_batch.forward_mode.is_extend()
+        and not forward_batch.forward_mode.is_target_verify()
+        and not forward_batch.forward_mode.is_draft_extend()
+    )
+
+
+def _handle_attention_backend(
+    attn: DeepseekV2AttentionMLA, forward_batch, backend_name
+):
+    sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
+    disable_ragged = (
+        backend_name in ["flashinfer", "flashmla"]
+    ) and attn.flashinfer_mla_disable_ragged
+
+    if (
+        not disable_ragged
+        and _is_extend_without_speculative(forward_batch)
+        and (
+            (
+                sum_extend_prefix_lens >= attn.chunked_prefix_cache_threshold
+                and not attn.disable_chunked_prefix_cache
+            )
+            or sum_extend_prefix_lens == 0
+        )
+    ):
+        return AttnForwardMethod.MHA_CHUNKED_KV
+    else:
+        return _dispatch_mla_subtype(attn, forward_batch)
+
+
+def handle_attention_flashinfer(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "flashinfer")
+
+
+def handle_attention_fa3(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "fa3")
+
+
+def handle_attention_flashmla(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "flashmla")
+
+
+def handle_attention_cutlass_mla(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "cutlass_mla")
+
+
+def handle_attention_fa4(attn, forward_batch):
+    # TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
+    return AttnForwardMethod.MHA_CHUNKED_KV
+
+
+def handle_attention_trtllm_mla(attn, forward_batch):
+    sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
+    if _is_extend_without_speculative(forward_batch) and (
+        not attn.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
+    ):
+        return AttnForwardMethod.MHA_CHUNKED_KV
+    else:
+        return _dispatch_mla_subtype(attn, forward_batch)
+
+
+def handle_attention_aiter(attn, forward_batch):
+    if _is_extend_without_speculative(forward_batch):
+        if is_dp_attention_enabled():
+            if sum(forward_batch.extend_prefix_lens_cpu) == 0:
+                return AttnForwardMethod.MHA
+            else:
+                return AttnForwardMethod.MLA
+        else:
+            return AttnForwardMethod.MHA
+    else:
+        return AttnForwardMethod.MLA
+
+
+def handle_attention_nsa(attn, forward_batch):
+    return AttnForwardMethod.MLA
+
+
+def handle_attention_triton(attn, forward_batch):
+    if (
+        _is_extend_without_speculative(forward_batch)
+        and sum(forward_batch.extend_prefix_lens_cpu) == 0
+    ):
+        return AttnForwardMethod.MHA
+    else:
+        return _dispatch_mla_subtype(attn, forward_batch)
+
+
 class DeepseekV2MLP(nn.Module):
     def __init__(
         self,
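The AttentionBackendRegistry added in the hunk above is a simple class-level mapping from backend name to a handler that returns an AttnForwardMethod, with the "triton" handler as the fallback for unknown names. A minimal sketch of the register/dispatch pattern, using only the methods shown above; whether sglang wires the handlers up exactly this way is not visible in this diff:

    # Sketch only: register handlers defined above, then dispatch by backend name.
    AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
    AttentionBackendRegistry.register("triton", handle_attention_triton)

    def choose_forward_method(attn, forward_batch, backend_name):
        # Unknown backend names fall back to the "triton" handler.
        handler = AttentionBackendRegistry.get_handler(backend_name)
        return handler(attn, forward_batch)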
@@ -309,7 +501,7 @@ class MoEGate(nn.Module):
             _is_cuda
             and hidden_states.shape[0] <= 16
             and hidden_states.shape[1] == 7168
-            and self.weight.shape[0] == 256
+            and (self.weight.shape[0] == 256 or self.weight.shape[0] == 384)
             and _device_sm >= 90
         ):
             # router gemm output float32
@@ -343,12 +535,13 @@ class DeepseekV2MoE(nn.Module):
         self.n_shared_experts = config.n_shared_experts
         self.num_fused_shared_experts = (
             0
-            if global_server_args_dict["disable_shared_experts_fusion"]
+            if get_global_server_args().disable_shared_experts_fusion
             else config.n_shared_experts
         )
         self.config = config
         self.layer_id = layer_id
         self.alt_stream = alt_stream
+        self.is_nextn = is_nextn

         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -372,7 +565,7 @@ class DeepseekV2MoE(nn.Module):
         self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
-            + global_server_args_dict["ep_num_redundant_experts"],
+            + get_global_server_args().ep_num_redundant_experts,
             num_fused_shared_experts=self.num_fused_shared_experts,
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
@@ -393,7 +586,7 @@ class DeepseekV2MoE(nn.Module):
             correction_bias=self.gate.e_score_correction_bias,
             quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
-            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
+            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk,
             # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
             # and requires the output format to be standard. We use quant_config to determine the output format.
             output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
@@ -415,6 +608,7 @@ class DeepseekV2MoE(nn.Module):
             **(
                 dict(tp_rank=0, tp_size=1)
                 if get_moe_a2a_backend().is_deepep()
+                or get_moe_a2a_backend().is_mooncake()
                 or should_use_flashinfer_cutlass_moe_fp4_allgather()
                 else {}
             ),
@@ -445,12 +639,12 @@ class DeepseekV2MoE(nn.Module):

         self.top_k = config.num_experts_per_tok

-        if get_moe_a2a_backend().is_deepep():
+        if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
             # TODO: we will support tp < ep in the future
             self.ep_size = get_moe_expert_parallel_world_size()
             self.num_experts = (
                 config.n_routed_experts
-                + global_server_args_dict["ep_num_redundant_experts"]
+                + get_global_server_args().ep_num_redundant_experts
             )
             self.renormalize = config.norm_topk_prob
             self.topk_group = config.topk_group
@@ -461,20 +655,10 @@ class DeepseekV2MoE(nn.Module):
                 else None
             )

-            self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
-                group=parallel_state.get_tp_group().device_group,
-                router_topk=self.top_k,
-                permute_fusion=True,
-                num_experts=self.num_experts,
-                num_local_experts=config.n_routed_experts // self.tp_size,
-                hidden_size=config.hidden_size,
-                params_dtype=config.torch_dtype,
-                deepep_mode=get_deepep_mode(),
-                async_finish=True,
-                return_recv_hook=True,
-            )
-
-        self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()
+        self._enable_a2a_moe = (
+            get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake()
+        )
+        self._fuse_shared_experts_inside_sbo = SboFlags.fuse_shared_experts_inside_sbo()

     def get_moe_weights(self):
         return [
@@ -491,7 +675,7 @@ class DeepseekV2MoE(nn.Module):
         use_reduce_scatter: bool = False,
         gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
-        if not self._enable_deepep_moe:
+        if not self._enable_a2a_moe:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024
             if (
                 self.alt_stream is not None
@@ -533,6 +717,10 @@ class DeepseekV2MoE(nn.Module):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
+            if isinstance(
+                self.experts.quant_method, CompressedTensorsWNA16AMXEPMoEMethod
+            ):
+                topk_output.topk_weights.mul_(self.routed_scaling_factor)
             final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
@@ -566,9 +754,10 @@ class DeepseekV2MoE(nn.Module):
             return self.forward_cpu(hidden_states, should_allreduce_fusion)

         if hidden_states.shape[0] > 0:
-            shared_output = self._forward_shared_experts(
-                hidden_states, gemm_output_zero_allocator
-            )
+            if not self._fuse_shared_experts_inside_sbo:
+                shared_output = self._forward_shared_experts(
+                    hidden_states, gemm_output_zero_allocator
+                )
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
@@ -576,7 +765,27 @@ class DeepseekV2MoE(nn.Module):
             shared_output = None
             topk_output = self.topk.empty_topk_output(hidden_states.device)

-        final_hidden_states = self.experts(hidden_states, topk_output)
+        if self._fuse_shared_experts_inside_sbo:
+            shared_output = None
+
+            def _forward_shared_experts_and_put_results():
+                nonlocal shared_output
+                shared_output = self._forward_shared_experts(
+                    hidden_states, gemm_output_zero_allocator
+                )
+
+        final_hidden_states = self.experts(
+            hidden_states,
+            topk_output,
+            **(
+                dict(
+                    forward_shared_experts=_forward_shared_experts_and_put_results,
+                    alt_stream=self.alt_stream,
+                )
+                if self._fuse_shared_experts_inside_sbo
+                else {}
+            ),
+        )
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
@@ -660,8 +869,9 @@ class DeepseekV2MoE(nn.Module):
         if hidden_states.shape[0] > 0:
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            shared_output = self._forward_shared_experts(hidden_states)
-            topk_weights, topk_idx, _ = self.topk(
+            if not self._fuse_shared_experts_inside_sbo:
+                shared_output = self._forward_shared_experts(hidden_states)
+            topk_output = self.topk(
                 hidden_states,
                 router_logits,
                 num_token_non_padded=forward_batch.num_token_non_padded,
@@ -670,26 +880,39 @@ class DeepseekV2MoE(nn.Module):
                 ),
             )
         else:
-            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
-                hidden_states.device
-            )
+            topk_output = self.topk.empty_topk_output(hidden_states.device)
+
+        if self._fuse_shared_experts_inside_sbo:
+            shared_output = None
+
+            def _forward_shared_experts_and_put_results():
+                nonlocal shared_output
+                shared_output = self._forward_shared_experts(hidden_states)

         final_hidden_states = self.experts(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
-            forward_batch=forward_batch,
+            topk_output=topk_output,
+            **(
+                dict(
+                    forward_shared_experts=_forward_shared_experts_and_put_results,
+                    alt_stream=self.alt_stream,
+                    # SBO is not yet implemented for NextN
+                    disable_sbo=self.is_nextn,
+                )
+                if self._fuse_shared_experts_inside_sbo
+                else {}
+            ),
         )

         if shared_output is not None:
             x = shared_output
-            if self.experts.should_fuse_routed_scaling_factor_in_topk():
+            if self.experts.should_fuse_routed_scaling_factor_in_topk:
                 x.add_(final_hidden_states)
             else:
                 x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
             final_hidden_states = x
         else:
-            if not self.experts.should_fuse_routed_scaling_factor_in_topk():
+            if not self.experts.should_fuse_routed_scaling_factor_in_topk:
                 final_hidden_states *= self.routed_scaling_factor

         return final_hidden_states
@@ -697,7 +920,7 @@ class DeepseekV2MoE(nn.Module):
     def _forward_shared_experts(
         self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None
     ):
-        if self.num_fused_shared_experts == 0:
+        if (hidden_states.shape[0] > 0) and (self.num_fused_shared_experts == 0):
             return self.shared_experts(
                 hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
             )
@@ -730,7 +953,7 @@ class DeepseekV2MoE(nn.Module):
             with get_global_expert_distribution_recorder().with_current_layer(
                 self.layer_id
             ):
-                state.topk_weights_local, state.topk_idx_local, _ = self.topk(
+                state.topk_output = self.topk(
                     hidden_states=hidden_states,
                     router_logits=router_logits,
                     num_token_non_padded=state.forward_batch.num_token_non_padded,
@@ -739,20 +962,13 @@ class DeepseekV2MoE(nn.Module):
                     ),
                 )
         else:
-            state.topk_idx_local = torch.full(
-                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
-            )
-            state.topk_weights_local = torch.empty(
-                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
-            )
+            state.topk_output = self.topk.empty_topk_output(hidden_states.device)

     def op_dispatch_a(self, state):
         if self.ep_size > 1:
-            self.experts.deepep_dispatcher.dispatch_a(
+            self.experts.dispatcher.dispatch_a(
                 hidden_states=state.hidden_states_mlp_input,
-                topk_idx=state.pop("topk_idx_local"),
-                topk_weights=state.pop("topk_weights_local"),
-                forward_batch=state.forward_batch,
+                topk_output=state.pop("topk_output"),
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )

@@ -761,32 +977,29 @@ class DeepseekV2MoE(nn.Module):
             with get_global_expert_distribution_recorder().with_current_layer(
                 self.layer_id
             ):
-                state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
+                state.dispatch_output = self.experts.dispatcher.dispatch_b(
                     tbo_subbatch_index=state.get("tbo_subbatch_index"),
                 )

     def op_experts(self, state):
-        state.hidden_states_experts_output = self.experts.moe_impl(
+        state.hidden_states_experts_output = self.experts.run_moe_core(
             dispatch_output=state.dispatch_output,
         )

     def op_combine_a(self, state):
         if self.ep_size > 1:
-            self.experts.deepep_dispatcher.combine_a(
+            self.experts.dispatcher.combine_a(
                 hidden_states=state.pop("hidden_states_experts_output"),
-                topk_idx=state.dispatch_output.topk_idx,
+                topk_ids=state.dispatch_output.topk_ids,
                 topk_weights=state.dispatch_output.topk_weights,
-                forward_batch=state.forward_batch,
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
             state.pop("dispatch_output")

     def op_combine_b(self, state):
         if self.ep_size > 1:
-            state.hidden_states_after_combine = (
-                self.experts.deepep_dispatcher.combine_b(
-                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
-                )
+            state.hidden_states_after_combine = self.experts.dispatcher.combine_b(
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )

     def op_output(self, state):
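For orientation, the op_* methods touched above split one MoE layer into gate/top-k, dispatch, expert, and combine phases so the two-batch-overlap scheduler can interleave two micro-batches. A rough outline of the phase order these hunks imply; the names not shown in this diff (op_gate, op_select_experts, op_dispatch_b) and the driver function itself are assumptions, not sglang's actual scheduler code:

    # Rough outline only; the real two-batch-overlap scheduler interleaves these
    # phases across two micro-batches instead of running them back to back.
    def run_moe_layer(moe, state):
        moe.op_gate(state)            # assumed name: computes router logits
        moe.op_select_experts(state)  # assumed name: fills state.topk_output
        moe.op_dispatch_a(state)      # dispatcher.dispatch_a(...)
        moe.op_dispatch_b(state)      # assumed name: dispatcher.dispatch_b(...)
        moe.op_experts(state)         # experts.run_moe_core(...)
        moe.op_combine_a(state)       # dispatcher.combine_a(...)
        moe.op_combine_b(state)       # dispatcher.combine_b(...)
        moe.op_output(state)          # produces the layer output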
@@ -850,6 +1063,10 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings

+        # NOTE modification to rope_scaling must be done early enough, b/c e.g. Indexer needs it
+        if rope_scaling:
+            rope_scaling["rope_type"] = "deepseek_yarn"
+
         # For tensor parallel attention
         if self.q_lora_rank is not None:
             self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
@@ -864,7 +1081,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 q_lora_rank,
                 self.num_heads * self.qk_head_dim,
                 bias=False,
-                quant_config=quant_config,
+                quant_config=self._get_q_b_proj_quant_config(quant_config),
                 prefix=add_prefix("q_b_proj", prefix),
                 tp_rank=attn_tp_rank,
                 tp_size=attn_tp_size,
@@ -887,6 +1104,26 @@ class DeepseekV2AttentionMLA(nn.Module):
                 prefix=add_prefix("kv_a_proj_with_mqa", prefix),
             )

+        self.use_nsa = is_deepseek_nsa(config)
+        if self.use_nsa:
+            self.indexer = Indexer(
+                hidden_size=hidden_size,
+                index_n_heads=get_nsa_index_n_heads(config),
+                index_head_dim=get_nsa_index_head_dim(config),
+                rope_head_dim=qk_rope_head_dim,
+                index_topk=get_nsa_index_topk(config),
+                q_lora_rank=q_lora_rank,
+                max_position_embeddings=max_position_embeddings,
+                rope_theta=rope_theta,
+                scale_fmt="ue8m0",
+                block_size=128,
+                rope_scaling=rope_scaling,
+                prefix=add_prefix("indexer", prefix),
+                quant_config=quant_config,
+                layer_id=layer_id,
+                alt_stream=alt_stream,
+            )
+
         self.kv_b_proj = ColumnParallelLinear(
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
@@ -909,9 +1146,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)

-        if rope_scaling:
-            rope_scaling["rope_type"] = "deepseek_yarn"
-
         self.rotary_emb = get_rope_wrapper(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
@@ -919,7 +1153,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             base=rope_theta,
             rope_scaling=rope_scaling,
             is_neox_style=False,
-            device=global_server_args_dict["device"],
+            device=get_global_server_args().device,
         )

         if rope_scaling:
@@ -963,12 +1197,12 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.w_scale_v = None
         self.use_deep_gemm_bmm = False

-        self.flashinfer_mla_disable_ragged = global_server_args_dict[
-            "flashinfer_mla_disable_ragged"
-        ]
-        self.disable_chunked_prefix_cache = global_server_args_dict[
-            "disable_chunked_prefix_cache"
-        ]
+        self.flashinfer_mla_disable_ragged = (
+            get_global_server_args().flashinfer_mla_disable_ragged
+        )
+        self.disable_chunked_prefix_cache = (
+            get_global_server_args().disable_chunked_prefix_cache
+        )

         self.current_attention_backend = (
             None  # Attention backend used by current forward batch
@@ -1035,146 +1269,34 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.weight_block_size = (
             self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
         )
+        self.is_mla_preprocess_enabled = is_mla_preprocess_enabled()
+        if self.is_mla_preprocess_enabled:
+            assert (
+                quant_config is None or quant_config.get_name() == "w8a8_int8"
+            ), "MLA Preprocess only works with Unquant or W8A8Int8"
+            self.mla_preprocess = None

     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
     ) -> AttnForwardMethod:
-        def _dispatch_mla_subtype():
-            if _is_hip:
-                if (
-                    self.rocm_fused_decode_mla
-                    and forward_batch.forward_mode.is_decode()
-                ):
-                    return AttnForwardMethod.MLA_FUSED_ROPE
-                else:
-                    return AttnForwardMethod.MLA
-            else:
-                if hasattr(self, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(
-                    self
-                ):
-                    return AttnForwardMethod.MLA_FUSED_ROPE_CPU
-                else:
-                    return AttnForwardMethod.MLA
-
         # Determine attention backend used by current forward batch
         if forward_batch.forward_mode.is_decode_or_idle():
-            attention_backend = global_server_args_dict["decode_attention_backend"]
+            attention_backend = get_global_server_args().decode_attention_backend
         elif (
             forward_batch.forward_mode.is_target_verify()
             or forward_batch.forward_mode.is_draft_extend()
         ):
             # Use the specified backend for speculative operations (both verify and draft extend)
-            if global_server_args_dict["speculative_attention_mode"] == "decode":
-                attention_backend = global_server_args_dict["decode_attention_backend"]
+            if get_global_server_args().speculative_attention_mode == "decode":
+                attention_backend = get_global_server_args().decode_attention_backend
             else:  # default to prefill
-                attention_backend = global_server_args_dict["prefill_attention_backend"]
+                attention_backend = get_global_server_args().prefill_attention_backend
         else:
-            attention_backend = global_server_args_dict["prefill_attention_backend"]
+            attention_backend = get_global_server_args().prefill_attention_backend
         self.current_attention_backend = attention_backend

-        if attention_backend == "ascend":
-            if (
-                forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-            ):
-                return AttnForwardMethod.MHA
-            else:
-                return AttnForwardMethod.MLA
-        elif (
-            attention_backend == "flashinfer"
-            or attention_backend == "fa3"
-            or attention_backend == "flashmla"
-            or attention_backend == "cutlass_mla"
-        ):
-            # Use MHA with chunked KV cache when prefilling on long sequences.
-            sum_extend_prefix_lens = (
-                sum(forward_batch.extend_prefix_lens_cpu)
-                if forward_batch.extend_prefix_lens_cpu is not None
-                else 0
-            )
-            # Flashinfer MLA: Do not absorb when enabling ragged prefill
-            disable_ragged = (
-                attention_backend == "flashinfer" or attention_backend == "flashmla"
-            ) and self.flashinfer_mla_disable_ragged
-
-            original_mode = getattr(forward_batch, "_original_forward_mode", None)
-            if (
-                not disable_ragged
-                and forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-                and (
-                    (
-                        sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
-                        and not self.disable_chunked_prefix_cache
-                    )
-                    or sum_extend_prefix_lens == 0
-                )
-                # TODO(shuw@nvidia.com) Flashinfer cutlass and trtllm_mla backend have accuracy issue on blackwell for
-                # dp case. Redirect to mla kernel as a workaround.
-                # Tracked by https://github.com/sgl-project/sglang/issues/9806.
-                and not (
-                    original_mode is not None
-                    and original_mode.is_decode()
-                    and is_sm100_supported()
-                    and self.current_attention_backend in ("cutlass_mla", "flashinfer")
-                )
-            ):
-                return AttnForwardMethod.MHA_CHUNKED_KV
-            else:
-                return _dispatch_mla_subtype()
-        elif attention_backend == "trtllm_mla":
-            original_mode = getattr(forward_batch, "_original_forward_mode", None)
-            if (
-                original_mode is not None
-                and original_mode.is_decode()
-                and is_sm100_supported()
-            ):
-                return _dispatch_mla_subtype()
-
-            sum_extend_prefix_lens = (
-                sum(forward_batch.extend_prefix_lens_cpu)
-                if forward_batch.extend_prefix_lens_cpu is not None
-                else 0
-            )
-            if (
-                forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-                and (
-                    not self.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
-                )
-            ):
-                return AttnForwardMethod.MHA_CHUNKED_KV
-            else:
-                return _dispatch_mla_subtype()
-        elif attention_backend == "aiter":
-            if (
-                forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-            ):
-                if is_dp_attention_enabled():
-                    if sum(forward_batch.extend_prefix_lens_cpu) == 0:
-                        return AttnForwardMethod.MHA
-                    else:
-                        return AttnForwardMethod.MLA
1163
- else:
1164
- return AttnForwardMethod.MHA
1165
- else:
1166
- return AttnForwardMethod.MLA
1167
- else:
1168
- # Triton: Use normal computation for prefill and use weight absorption for extend/decode
1169
- if (
1170
- forward_batch.forward_mode.is_extend()
1171
- and not forward_batch.forward_mode.is_target_verify()
1172
- and not forward_batch.forward_mode.is_draft_extend()
1173
- and sum(forward_batch.extend_prefix_lens_cpu) == 0
1174
- ):
1175
- return AttnForwardMethod.MHA
1176
- else:
1177
- return _dispatch_mla_subtype()
1298
+ handler = AttentionBackendRegistry.get_handler(attention_backend)
1299
+ return handler(self, forward_batch)
1178
1300
 
1179
1301
  def op_prepare(self, state):
1180
1302
  state.attn_intermediate_state = self.forward_prepare(
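Note: the long per-backend if/elif chain in `dispatch_attn_forward_method` is replaced by a lookup through `AttentionBackendRegistry`; the per-backend handlers (`handle_attention_flashinfer`, `handle_attention_trtllm_mla`, and so on) are registered near the end of this file. A simplified sketch of such a registry, assuming a plain dict with a `"triton"` fallback (the real class may differ):

    class AttentionBackendRegistry:
        _handlers: dict = {}

        @classmethod
        def register(cls, name, handler):
            # handler(attn_module, forward_batch) -> AttnForwardMethod
            cls._handlers[name] = handler

        @classmethod
        def get_handler(cls, name):
            # Assumption: unknown backends fall back to the generic triton handler.
            return cls._handlers.get(name, cls._handlers["triton"])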
@@ -1229,7 +1351,6 @@ class DeepseekV2AttentionMLA(nn.Module):
1229
1351
  return hidden_states, None, forward_batch, None
1230
1352
 
1231
1353
  attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
1232
-
1233
1354
  if attn_forward_method == AttnForwardMethod.MHA:
1234
1355
  inner_state = self.forward_normal_prepare(
1235
1356
  positions, hidden_states, forward_batch, zero_allocator
@@ -1239,7 +1360,31 @@ class DeepseekV2AttentionMLA(nn.Module):
1239
1360
  positions, hidden_states, forward_batch, zero_allocator
1240
1361
  )
1241
1362
  elif attn_forward_method == AttnForwardMethod.MLA:
1242
- inner_state = self.forward_absorb_prepare(
1363
+ if not self.is_mla_preprocess_enabled:
1364
+ inner_state = self.forward_absorb_prepare(
1365
+ positions, hidden_states, forward_batch, zero_allocator
1366
+ )
1367
+ else:
1368
+ # TODO(iforgetmyname): split this out into a standalone function
1369
+ if self.mla_preprocess is None:
1370
+ self.mla_preprocess = NPUFusedMLAPreprocess(
1371
+ self.fused_qkv_a_proj_with_mqa,
1372
+ self.q_a_layernorm,
1373
+ self.kv_a_layernorm,
1374
+ self.q_b_proj,
1375
+ self.w_kc,
1376
+ self.rotary_emb,
1377
+ self.layer_id,
1378
+ self.num_local_heads,
1379
+ self.qk_nope_head_dim,
1380
+ self.qk_rope_head_dim,
1381
+ )
1382
+ inner_state = self.mla_preprocess.forward(
1383
+ positions, hidden_states, forward_batch, zero_allocator
1384
+ )
1385
+ inner_state = (*inner_state, None) # add a position for topk_indices
1386
+ elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
1387
+ inner_state = self.forward_npu_sparse_prepare(
1243
1388
  positions, hidden_states, forward_batch, zero_allocator
1244
1389
  )
1245
1390
  elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
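Note: when `is_mla_preprocess_enabled()` is true, the MLA prepare step is fused into `NPUFusedMLAPreprocess`, built lazily on the first forward, presumably because inputs such as `self.w_kc` only exist after weight loading. Its output tuple is padded with a trailing `None` so it lines up with the `topk_indices` slot that `forward_absorb_core` now expects; a minimal illustration of that padding:

    inner_state = self.mla_preprocess.forward(
        positions, hidden_states, forward_batch, zero_allocator
    )
    # forward_absorb_core(q_pe, k_pe, q_nope_out, k_nope, forward_batch,
    #                     zero_allocator, positions, topk_indices)
    inner_state = (*inner_state, None)  # no topk_indices on this path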
@@ -1267,6 +1412,8 @@ class DeepseekV2AttentionMLA(nn.Module):
1267
1412
  return self.forward_normal_chunked_kv_core(*inner_state)
1268
1413
  elif attn_forward_method == AttnForwardMethod.MLA:
1269
1414
  return self.forward_absorb_core(*inner_state)
1415
+ elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
1416
+ return self.forward_npu_sparse_core(*inner_state)
1270
1417
  elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
1271
1418
  return self.forward_absorb_fused_mla_rope_core(*inner_state)
1272
1419
  elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
@@ -1346,7 +1493,10 @@ class DeepseekV2AttentionMLA(nn.Module):
1346
1493
  """
1347
1494
  return (
1348
1495
  self.current_attention_backend == "trtllm_mla"
1349
- and forward_batch.forward_mode.is_decode_or_idle()
1496
+ and (
1497
+ forward_batch.forward_mode.is_decode_or_idle()
1498
+ or forward_batch.forward_mode.is_target_verify()
1499
+ )
1350
1500
  and forward_batch.attn_backend.data_type == torch.float8_e4m3fn
1351
1501
  )
1352
1502
 
@@ -1359,6 +1509,7 @@ class DeepseekV2AttentionMLA(nn.Module):
1359
1509
  ):
1360
1510
  from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
1361
1511
 
1512
+ q_lora = None
1362
1513
  if self.q_lora_rank is not None:
1363
1514
  if (
1364
1515
  (not isinstance(hidden_states, tuple))
@@ -1397,6 +1548,10 @@ class DeepseekV2AttentionMLA(nn.Module):
1397
1548
  q = self.q_a_layernorm(q)
1398
1549
  k_nope = self.kv_a_layernorm(k_nope)
1399
1550
 
1551
+ # q_lora is needed by the NSA indexer
1552
+ if self.use_nsa:
1553
+ q_lora = q
1554
+
1400
1555
  k_nope = k_nope.unsqueeze(1)
1401
1556
  q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
1402
1557
  else:
@@ -1449,9 +1604,14 @@ class DeepseekV2AttentionMLA(nn.Module):
1449
1604
  self.w_kc.to(torch.bfloat16) * self.w_scale,
1450
1605
  )
1451
1606
  elif self.w_kc.dtype == torch.float8_e4m3fn:
1607
+ # Fix a bmm_fp8 error under cuBLAS 12.9 caused by the BumpAllocator; see PR #11612 for details
1452
1608
  q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
1453
1609
  q_nope.transpose(0, 1),
1454
- zero_allocator.allocate(1),
1610
+ (
1611
+ torch.zeros((1,), dtype=torch.float32, device=q_nope.device)
1612
+ if _is_cublas_ge_129
1613
+ else zero_allocator.allocate(1)
1614
+ ),
1455
1615
  )
1456
1616
  q_nope_out = bmm_fp8(
1457
1617
  q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -1462,28 +1622,50 @@ class DeepseekV2AttentionMLA(nn.Module):
1462
1622
  q_nope_out = q_nope_out.transpose(0, 1)
1463
1623
 
1464
1624
  if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
1465
- not _use_aiter or not _is_gfx95_supported
1625
+ not _use_aiter or not _is_gfx95_supported or self.use_nsa
1466
1626
  ):
1467
1627
  q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
1468
1628
 
1469
- return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
1629
+ topk_indices = None
1630
+ if q_lora is not None:
1631
+ topk_indices = self.indexer(
1632
+ x=hidden_states,
1633
+ q_lora=q_lora,
1634
+ positions=positions,
1635
+ forward_batch=forward_batch,
1636
+ layer_id=self.layer_id,
1637
+ )
1638
+
1639
+ return (
1640
+ q_pe,
1641
+ k_pe,
1642
+ q_nope_out,
1643
+ k_nope,
1644
+ forward_batch,
1645
+ zero_allocator,
1646
+ positions,
1647
+ topk_indices,
1648
+ )
1470
1649
 
1471
1650
  def forward_absorb_core(
1472
- self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
1651
+ self,
1652
+ q_pe,
1653
+ k_pe,
1654
+ q_nope_out,
1655
+ k_nope,
1656
+ forward_batch,
1657
+ zero_allocator,
1658
+ positions,
1659
+ topk_indices,
1473
1660
  ):
1474
- if (
1475
- self.current_attention_backend == "fa3"
1476
- or self.current_attention_backend == "flashinfer"
1477
- or self.current_attention_backend == "cutlass_mla"
1478
- or self.current_attention_backend == "trtllm_mla"
1479
- or self.current_attention_backend == "ascend"
1480
- ):
1661
+ if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
1481
1662
  extra_args = {}
1482
1663
  if self._fuse_rope_for_trtllm_mla(forward_batch):
1483
1664
  extra_args = {
1484
1665
  "cos_sin_cache": self.rotary_emb.cos_sin_cache,
1485
1666
  "is_neox": self.rotary_emb.is_neox_style,
1486
1667
  }
1668
+
1487
1669
  attn_output = self.attn_mqa(
1488
1670
  q_nope_out,
1489
1671
  k_nope,
@@ -1492,6 +1674,7 @@ class DeepseekV2AttentionMLA(nn.Module):
1492
1674
  q_rope=q_pe,
1493
1675
  k_rope=k_pe,
1494
1676
  **extra_args,
1677
+ **(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
1495
1678
  )
1496
1679
  else:
1497
1680
  if _use_aiter_gfx95:
@@ -1511,7 +1694,13 @@ class DeepseekV2AttentionMLA(nn.Module):
1511
1694
  q = torch.cat([q_nope_out, q_pe], dim=-1)
1512
1695
  k = torch.cat([k_nope, k_pe], dim=-1)
1513
1696
 
1514
- attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
1697
+ attn_output = self.attn_mqa(
1698
+ q,
1699
+ k,
1700
+ k_nope,
1701
+ forward_batch,
1702
+ **(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
1703
+ )
1515
1704
  attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
1516
1705
 
1517
1706
  if self.use_deep_gemm_bmm:
@@ -1566,7 +1755,11 @@ class DeepseekV2AttentionMLA(nn.Module):
1566
1755
  elif self.w_vc.dtype == torch.float8_e4m3fn:
1567
1756
  attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
1568
1757
  attn_output.transpose(0, 1),
1569
- zero_allocator.allocate(1),
1758
+ (
1759
+ torch.zeros((1,), dtype=torch.float32, device=attn_output.device)
1760
+ if _is_cublas_ge_129
1761
+ else zero_allocator.allocate(1)
1762
+ ),
1570
1763
  )
1571
1764
  attn_bmm_output = bmm_fp8(
1572
1765
  attn_output_val,
@@ -1593,6 +1786,221 @@ class DeepseekV2AttentionMLA(nn.Module):
1593
1786
 
1594
1787
  return output
1595
1788
 
1789
+ def forward_npu_sparse_prepare(
1790
+ self,
1791
+ positions: torch.Tensor,
1792
+ hidden_states: torch.Tensor,
1793
+ forward_batch: ForwardBatch,
1794
+ zero_allocator: BumpAllocator,
1795
+ ):
1796
+ """
1797
+ Reuses the `self.q_lora_rank is not None` branch of forward_absorb_prepare.
1798
+ """
1799
+ if self.is_mla_preprocess_enabled and forward_batch.forward_mode.is_decode():
1800
+ if self.mla_preprocess is None:
1801
+ self.mla_preprocess = NPUFusedMLAPreprocess(
1802
+ self.fused_qkv_a_proj_with_mqa,
1803
+ self.q_a_layernorm,
1804
+ self.kv_a_layernorm,
1805
+ self.q_b_proj,
1806
+ self.w_kc,
1807
+ self.rotary_emb,
1808
+ self.layer_id,
1809
+ self.num_local_heads,
1810
+ self.qk_nope_head_dim,
1811
+ self.qk_rope_head_dim,
1812
+ )
1813
+ (
1814
+ q_pe,
1815
+ k_pe,
1816
+ q_nope_out,
1817
+ k_nope,
1818
+ forward_batch,
1819
+ zero_allocator,
1820
+ positions,
1821
+ ) = self.mla_preprocess.forward(
1822
+ positions, hidden_states, forward_batch, zero_allocator
1823
+ )
1824
+
1825
+ fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
1826
+ q, _ = fused_qkv_a_proj_out.split(
1827
+ [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
1828
+ )
1829
+ q_lora = self.q_a_layernorm(q)
1830
+ else:
1831
+ from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
1832
+
1833
+ if (
1834
+ (not isinstance(hidden_states, tuple))
1835
+ and hidden_states.shape[0] <= 16
1836
+ and self.use_min_latency_fused_a_gemm
1837
+ ):
1838
+ fused_qkv_a_proj_out = dsv3_fused_a_gemm(
1839
+ hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
1840
+ )
1841
+ else:
1842
+ fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
1843
+ q, latent_cache = fused_qkv_a_proj_out.split(
1844
+ [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
1845
+ )
1846
+ k_nope = latent_cache[..., : self.kv_lora_rank]
1847
+
1848
+ # Overlap the q/k layernorms on the alternate CUDA stream
1849
+ if self.alt_stream is not None and get_is_capture_mode():
1850
+ current_stream = torch.cuda.current_stream()
1851
+ self.alt_stream.wait_stream(current_stream)
1852
+ q = self.q_a_layernorm(q)
1853
+ with torch.cuda.stream(self.alt_stream):
1854
+ k_nope = self.kv_a_layernorm(k_nope)
1855
+ current_stream.wait_stream(self.alt_stream)
1856
+ else:
1857
+ if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
1858
+ q, k_nope = fused_rms_mxfp4_quant(
1859
+ q,
1860
+ self.q_a_layernorm.weight,
1861
+ self.q_a_layernorm.variance_epsilon,
1862
+ k_nope,
1863
+ self.kv_a_layernorm.weight,
1864
+ self.kv_a_layernorm.variance_epsilon,
1865
+ )
1866
+ else:
1867
+ q = self.q_a_layernorm(q)
1868
+ k_nope = self.kv_a_layernorm(k_nope)
1869
+
1870
+ q_lora = q.clone() # required for topk_indices
1871
+ k_nope = k_nope.unsqueeze(1)
1872
+ q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
1873
+
1874
+ q_nope, q_pe = q.split(
1875
+ [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
1876
+ )
1877
+ k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
1878
+
1879
+ if self.use_deep_gemm_bmm:
1880
+ q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
1881
+ per_token_group_quant_mla_deep_gemm_masked_fp8(
1882
+ q_nope.transpose(0, 1)
1883
+ )
1884
+ )
1885
+ q_nope_out = q_nope.new_empty(
1886
+ (self.num_local_heads, aligned_m, self.kv_lora_rank)
1887
+ )
1888
+ deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
1889
+ (q_nope_val, q_nope_scale),
1890
+ (self.w_kc, self.w_scale_k),
1891
+ q_nope_out,
1892
+ masked_m,
1893
+ expected_m,
1894
+ )
1895
+ q_nope_out = q_nope_out[:, :expected_m, :]
1896
+ elif _is_hip:
1897
+ # TODO(haishaw): add bmm_fp8 to ROCm
1898
+ if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
1899
+ x = q_nope.transpose(0, 1)
1900
+ q_nope_out = torch.empty(
1901
+ x.shape[0],
1902
+ x.shape[1],
1903
+ self.w_kc.shape[2],
1904
+ device=x.device,
1905
+ dtype=torch.bfloat16,
1906
+ )
1907
+ batched_gemm_afp4wfp4_pre_quant(
1908
+ x,
1909
+ self.w_kc.transpose(-2, -1),
1910
+ self.w_scale_k.transpose(-2, -1),
1911
+ torch.bfloat16,
1912
+ q_nope_out,
1913
+ )
1914
+ else:
1915
+ q_nope_out = torch.bmm(
1916
+ q_nope.to(torch.bfloat16).transpose(0, 1),
1917
+ self.w_kc.to(torch.bfloat16) * self.w_scale,
1918
+ )
1919
+ elif self.w_kc.dtype == torch.float8_e4m3fn:
1920
+ q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
1921
+ q_nope.transpose(0, 1),
1922
+ zero_allocator.allocate(1),
1923
+ )
1924
+ q_nope_out = bmm_fp8(
1925
+ q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
1926
+ )
1927
+ else:
1928
+ q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
1929
+
1930
+ q_nope_out = q_nope_out.transpose(0, 1)
1931
+
1932
+ if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
1933
+ not _use_aiter or not _is_gfx95_supported
1934
+ ):
1935
+ q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
1936
+
1937
+ # TODO: multi-stream indexer
1938
+ topk_indices = self.indexer(
1939
+ hidden_states, q_lora, positions, forward_batch, self.layer_id
1940
+ )
1941
+
1942
+ return (
1943
+ q_pe,
1944
+ k_pe,
1945
+ q_nope_out,
1946
+ k_nope,
1947
+ topk_indices,
1948
+ forward_batch,
1949
+ zero_allocator,
1950
+ positions,
1951
+ )
1952
+
1953
+ def forward_npu_sparse_core(
1954
+ self,
1955
+ q_pe,
1956
+ k_pe,
1957
+ q_nope_out,
1958
+ k_nope,
1959
+ topk_indices,
1960
+ forward_batch,
1961
+ zero_allocator,
1962
+ positions,
1963
+ ):
1964
+ attn_output = self.attn_mqa(
1965
+ q_nope_out.contiguous(),
1966
+ k_nope.contiguous(),
1967
+ k_nope.contiguous(),
1968
+ forward_batch,
1969
+ save_kv_cache=True, # False if forward_batch.forward_mode.is_extend() else True,
1970
+ q_rope=q_pe.contiguous(),
1971
+ k_rope=k_pe.contiguous(),
1972
+ topk_indices=topk_indices,
1973
+ )
1974
+ attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
1975
+
1976
+ attn_bmm_output = torch.empty(
1977
+ (attn_output.shape[0], self.num_local_heads, self.v_head_dim),
1978
+ dtype=attn_output.dtype,
1979
+ device=attn_output.device,
1980
+ )
1981
+
1982
+ if not forward_batch.forward_mode.is_decode():
1983
+ attn_output = attn_output.transpose(0, 1)
1984
+ torch.bmm(
1985
+ attn_output,
1986
+ self.w_vc,
1987
+ out=attn_bmm_output.view(
1988
+ -1, self.num_local_heads, self.v_head_dim
1989
+ ).transpose(0, 1),
1990
+ )
1991
+ else:
1992
+ attn_output = attn_output.contiguous()
1993
+ torch.ops.npu.batch_matmul_transpose(
1994
+ attn_output, self.w_vc, attn_bmm_output
1995
+ )
1996
+
1997
+ attn_bmm_output = attn_bmm_output.reshape(
1998
+ -1, self.num_local_heads * self.v_head_dim
1999
+ )
2000
+
2001
+ output, _ = self.o_proj(attn_bmm_output)
2002
+ return output
2003
+
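Note on `forward_npu_sparse_core`: outside decode it writes the value projection directly into a transposed view of the pre-allocated output buffer, avoiding an extra transpose and copy (the decode path instead calls the fused NPU `batch_matmul_transpose` kernel). A small self-contained sketch of that pattern with illustrative shapes, not the model's real dimensions:

    import torch

    tokens, heads, kv_lora_rank, v_head_dim = 4, 16, 512, 128
    attn_output = torch.randn(heads, tokens, kv_lora_rank)   # (H, T, r)
    w_vc = torch.randn(heads, kv_lora_rank, v_head_dim)      # (H, r, Dv)
    out = torch.empty(tokens, heads, v_head_dim)             # (T, H, Dv)
    # bmm produces (H, T, Dv) and writes it straight into the transposed view of `out`.
    torch.bmm(attn_output, w_vc, out=out.transpose(0, 1))
    assert out.shape == (tokens, heads, v_head_dim)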
1596
2004
  def forward_absorb_fused_mla_rope_prepare(
1597
2005
  self,
1598
2006
  positions: torch.Tensor,
@@ -1918,6 +2326,7 @@ class DeepseekV2AttentionMLA(nn.Module):
1918
2326
  tmp_lse = torch.empty_like(accum_lse)
1919
2327
  merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
1920
2328
  accum_output, accum_lse = tmp_output, tmp_lse
2329
+ del kv, k, v, output, lse, tmp_output, tmp_lse
1921
2330
 
1922
2331
  return accum_output
1923
2332
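Note: the chunked-prefix path computes attention over the cached prefix chunk by chunk and combines the partial results with `merge_state_v2`; the added `del` releases each chunk's temporaries before the next iteration. Conceptually, two partial attention states `(o1, lse1)` and `(o2, lse2)` merge as a log-sum-exp weighted average; a small sketch of that math, not the actual kernel:

    import torch

    def merge_states(o1, lse1, o2, lse2):
        # o*: (tokens, heads, head_dim); lse*: (tokens, heads) log-sum-exp of scores
        lse = torch.logaddexp(lse1, lse2)
        w1 = torch.exp(lse1 - lse).unsqueeze(-1)
        w2 = torch.exp(lse2 - lse).unsqueeze(-1)
        return o1 * w1 + o2 * w2, lse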
 
@@ -1967,6 +2376,17 @@ class DeepseekV2AttentionMLA(nn.Module):
1967
2376
  output, _ = self.o_proj(attn_output)
1968
2377
  return output
1969
2378
 
2379
+ @staticmethod
2380
+ def _get_q_b_proj_quant_config(quant_config):
2381
+ if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
2382
+ # Mirror the quant config of the official DeepSeek V3 FP8 checkpoint
2383
+ return Fp8Config(
2384
+ is_checkpoint_fp8_serialized=True,
2385
+ weight_block_size=[128, 128],
2386
+ )
2387
+ else:
2388
+ return quant_config
2389
+
1970
2390
 
1971
2391
  class DeepseekV2DecoderLayer(nn.Module):
1972
2392
 
@@ -1975,6 +2395,7 @@ class DeepseekV2DecoderLayer(nn.Module):
1975
2395
  config: PretrainedConfig,
1976
2396
  layer_id: int,
1977
2397
  quant_config: Optional[QuantizationConfig] = None,
2398
+ moe_quant_config: Optional[QuantizationConfig] = None,
1978
2399
  is_nextn: bool = False,
1979
2400
  prefix: str = "",
1980
2401
  alt_stream: Optional[torch.cuda.Stream] = None,
@@ -1985,7 +2406,9 @@ class DeepseekV2DecoderLayer(nn.Module):
1985
2406
  rope_theta = getattr(config, "rope_theta", 10000)
1986
2407
  rope_scaling = getattr(config, "rope_scaling", None)
1987
2408
  max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
1988
- self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
2409
+ self.speculative_algorithm = SpeculativeAlgorithm.from_string(
2410
+ get_global_server_args().speculative_algorithm
2411
+ )
1989
2412
  self.layer_id = layer_id
1990
2413
  self.is_nextn = is_nextn
1991
2414
  self.self_attn = DeepseekV2AttentionMLA(
@@ -2022,7 +2445,7 @@ class DeepseekV2DecoderLayer(nn.Module):
2022
2445
  if self.is_layer_sparse:
2023
2446
  self.mlp = DeepseekV2MoE(
2024
2447
  config=config,
2025
- quant_config=quant_config,
2448
+ quant_config=moe_quant_config or quant_config,
2026
2449
  prefix=add_prefix("mlp", prefix),
2027
2450
  layer_id=self.layer_id,
2028
2451
  alt_stream=alt_stream,
@@ -2074,7 +2497,6 @@ class DeepseekV2DecoderLayer(nn.Module):
2074
2497
  zero_allocator: BumpAllocator,
2075
2498
  gemm_output_zero_allocator: BumpAllocator = None,
2076
2499
  ) -> torch.Tensor:
2077
-
2078
2500
  quant_format = (
2079
2501
  "mxfp4"
2080
2502
  if _is_gfx95_supported
@@ -2429,6 +2851,10 @@ class DeepseekV2ForCausalLM(nn.Module):
2429
2851
  self.config = config
2430
2852
  self.tp_size = get_tensor_model_parallel_world_size()
2431
2853
  self.quant_config = quant_config
2854
+ if envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set():
2855
+ CompressedTensorsConfig.DeepSeekFP8Config = Fp8Config(
2856
+ True, "dynamic", None, [128, 128]
2857
+ )
2432
2858
  self.determine_num_fused_shared_experts()
2433
2859
  self.model = DeepseekV2Model(
2434
2860
  config, quant_config, prefix=add_prefix("model", prefix)
@@ -2438,7 +2864,7 @@ class DeepseekV2ForCausalLM(nn.Module):
2438
2864
  config.hidden_size,
2439
2865
  quant_config=quant_config,
2440
2866
  prefix=add_prefix("lm_head", prefix),
2441
- use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
2867
+ use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
2442
2868
  )
2443
2869
  self.logits_processor = LogitsProcessor(config)
2444
2870
 
@@ -2458,7 +2884,7 @@ class DeepseekV2ForCausalLM(nn.Module):
2458
2884
  self, architecture: str = "DeepseekV3ForCausalLM"
2459
2885
  ):
2460
2886
  self.num_fused_shared_experts = 0
2461
- if global_server_args_dict["disable_shared_experts_fusion"]:
2887
+ if get_global_server_args().disable_shared_experts_fusion:
2462
2888
  return
2463
2889
 
2464
2890
  # Only Deepseek V3/R1 can use shared experts fusion optimization now.
@@ -2477,7 +2903,7 @@ class DeepseekV2ForCausalLM(nn.Module):
2477
2903
  disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."
2478
2904
 
2479
2905
  if disable_reason is not None:
2480
- global_server_args_dict["disable_shared_experts_fusion"] = True
2906
+ get_global_server_args().disable_shared_experts_fusion = True
2481
2907
  self.num_fused_shared_experts = 0
2482
2908
  log_info_on_rank0(
2483
2909
  logger,
@@ -2542,7 +2968,7 @@ class DeepseekV2ForCausalLM(nn.Module):
2542
2968
  )
2543
2969
  if hasattr(self_attn.kv_b_proj, "qweight"):
2544
2970
  # AWQ compatible
2545
- if _is_cuda or _is_hip:
2971
+ if _is_cuda or _is_hip or _is_npu:
2546
2972
  w = awq_dequantize(
2547
2973
  self_attn.kv_b_proj.qweight,
2548
2974
  self_attn.kv_b_proj.scales,
@@ -2568,11 +2994,13 @@ class DeepseekV2ForCausalLM(nn.Module):
2568
2994
  torch.float8_e4m3fn,
2569
2995
  torch.float8_e4m3fnuz,
2570
2996
  ):
2571
- if (
2572
- hasattr(self.quant_config, "weight_block_size")
2573
- and self.quant_config.weight_block_size is not None
2574
- ):
2575
- weight_block_size = self.quant_config.weight_block_size
2997
+ selected_quant_config = getattr(
2998
+ self.quant_config, "DeepSeekFP8Config", self.quant_config
2999
+ )
3000
+ weight_block_size = getattr(
3001
+ selected_quant_config, "weight_block_size", None
3002
+ )
3003
+ if weight_block_size is not None:
2576
3004
  assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
2577
3005
  if _is_fp8_fnuz:
2578
3006
  weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
@@ -2702,6 +3130,16 @@ class DeepseekV2ForCausalLM(nn.Module):
2702
3130
  ):
2703
3131
  self._weight_requant_ue8m0(is_nextn)
2704
3132
 
3133
+ # TODO: weight_requant_ue8m0 and transform_scale_ue8m0 could be moved into Fp8LinearMethod.process_weights_after_loading
3134
+ if (
3135
+ deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
3136
+ and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
3137
+ and get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN")
3138
+ ):
3139
+ self._transform_scale_ue8m0(is_nextn)
3140
+ if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
3141
+ self._transform_scale_nextn_moe_ue8m0()
3142
+
2705
3143
  def _weight_requant_ue8m0(self, is_nextn=False):
2706
3144
  weight_block_size = self.quant_config.weight_block_size
2707
3145
 
@@ -2767,6 +3205,47 @@ class DeepseekV2ForCausalLM(nn.Module):
2767
3205
  module.weight, module.weight_scale_inv, weight_block_size
2768
3206
  )
2769
3207
 
3208
+ # TODO: weight_requant_ue8m0 and transform_scale_ue8m0 could be moved into Fp8LinearMethod.process_weights_after_loading
3209
+ def _transform_scale_ue8m0(self, is_nextn=False):
3210
+ num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
3211
+
3212
+ for layer_id in range(num_hidden_layers):
3213
+ if is_nextn:
3214
+ layer = self.model.decoder
3215
+ else:
3216
+ layer = self.model.layers[layer_id]
3217
+
3218
+ module_list = []
3219
+ if self.config.q_lora_rank is not None:
3220
+ module_list.append(layer.self_attn.q_b_proj)
3221
+
3222
+ for module in module_list:
3223
+ transform_scale_ue8m0_inplace(
3224
+ module.weight_scale_inv, mn=module.weight.shape[-2]
3225
+ )
3226
+
3227
+ # TODO: avoid code duplication (currently combines logic from weight_requant_ue8m0 and transform_scale_ue8m0)
3228
+ def _transform_scale_nextn_moe_ue8m0(self):
3229
+ layer = self.model.decoder
3230
+
3231
+ shared_experts = getattr(layer.mlp, "shared_experts", None)
3232
+ if shared_experts is not None:
3233
+ for module in [
3234
+ shared_experts.gate_up_proj,
3235
+ shared_experts.down_proj,
3236
+ ]:
3237
+ transform_scale_ue8m0_inplace(
3238
+ module.weight_scale_inv, mn=module.weight.shape[-2]
3239
+ )
3240
+
3241
+ experts = layer.mlp.experts
3242
+ if isinstance(experts, DeepEPMoE):
3243
+ for w in [
3244
+ experts.w13_weight_fp8,
3245
+ experts.w2_weight_fp8,
3246
+ ]:
3247
+ transform_scale_ue8m0_inplace(w[1], mn=w[0].shape[-2])
3248
+
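Note: `_transform_scale_ue8m0` and `_transform_scale_nextn_moe_ue8m0` rewrite existing FP8 block scales in place into the UE8M0 (exponent-only, power-of-two) form that DeepGEMM expects when `DEEPGEMM_SCALE_UE8M0` is enabled. A hedged illustration of the core idea only; the real `transform_scale_ue8m0_inplace` also repacks the scale layout, which is omitted here:

    import torch

    def round_scale_to_ue8m0(scale: torch.Tensor) -> torch.Tensor:
        # Illustration/assumption: snap each block scale up to the nearest power of two
        # so it can be stored as an 8-bit exponent (UE8M0).
        return torch.exp2(torch.ceil(torch.log2(scale)))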
2770
3249
  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
2771
3250
 
2772
3251
  if is_nextn:
@@ -2782,6 +3261,13 @@ class DeepseekV2ForCausalLM(nn.Module):
2782
3261
  else:
2783
3262
  raise ValueError("num_nextn_predict_layers is not in the config")
2784
3263
 
3264
+ if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
3265
+ weights = self._quant_attn_to_fp8_ue8m0(weights, is_nextn=is_nextn)
3266
+ if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
3267
+ weights = self._quant_nextn_moe_to_fp8_ue8m0(
3268
+ weights, nextn_layer_id=nextn_layer_id
3269
+ )
3270
+
2785
3271
  stacked_params_mapping = [
2786
3272
  # (param_name, shard_name, shard_id)
2787
3273
  ("gate_up_proj", "gate_proj", 0),
@@ -3011,6 +3497,62 @@ class DeepseekV2ForCausalLM(nn.Module):
3011
3497
 
3012
3498
  self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
3013
3499
 
3500
+ def _quant_attn_to_fp8_ue8m0(self, weights, is_nextn):
3501
+ weights_dict = dict(weights)
3502
+
3503
+ # Temporarily, only DeepSeek V3/R1 is supported
3504
+ weight_block_size = [128, 128]
3505
+
3506
+ for layer_id in tqdm.trange(
3507
+ self.config.num_hidden_layers + int(is_nextn),
3508
+ desc="quant attn to fp8 ue8m0",
3509
+ ):
3510
+ for stem in [
3511
+ # Tensors such as `o_proj` may be added here for DeepSeek FP4 ckpt v1
3512
+ "q_b_proj",
3513
+ ]:
3514
+ partial_name = f"model.layers.{layer_id}.self_attn.{stem}"
3515
+ original_weight = weights_dict[f"{partial_name}.weight"]
3516
+ out_w, out_s = quant_weight_ue8m0(
3517
+ original_weight, weight_block_size=weight_block_size
3518
+ )
3519
+ weights_dict[f"{partial_name}.weight"] = out_w
3520
+ weights_dict[f"{partial_name}.weight_scale_inv"] = out_s
3521
+
3522
+ return list(weights_dict.items())
3523
+
3524
+ # TODO: avoid code duplication with _quant_attn_to_fp8_ue8m0
3525
+ def _quant_nextn_moe_to_fp8_ue8m0(self, weights, nextn_layer_id: int):
3526
+ weights_dict = dict(weights)
3527
+
3528
+ # Temporarily, only DeepSeek V3/R1 is supported
3529
+ weight_block_size = [128, 128]
3530
+
3531
+ for layer_id in [nextn_layer_id]:
3532
+ for expert_sub_name in [
3533
+ "shared_experts",
3534
+ *[
3535
+ f"experts.{expert_id}"
3536
+ for expert_id in range(self.config.n_routed_experts)
3537
+ ],
3538
+ ]:
3539
+ for stem in [
3540
+ "gate_proj",
3541
+ "up_proj",
3542
+ "down_proj",
3543
+ ]:
3544
+ partial_name = (
3545
+ f"model.layers.{layer_id}.mlp.{expert_sub_name}.{stem}"
3546
+ )
3547
+ original_weight = weights_dict[f"{partial_name}.weight"]
3548
+ out_w, out_s = quant_weight_ue8m0(
3549
+ original_weight, weight_block_size=weight_block_size
3550
+ )
3551
+ weights_dict[f"{partial_name}.weight"] = out_w
3552
+ weights_dict[f"{partial_name}.weight_scale_inv"] = out_s
3553
+
3554
+ return list(weights_dict.items())
3555
+
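Note: with `SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN` set, `load_weights` first rewrites selected attention weights (currently `q_b_proj`) into 128x128 block-quantized FP8 via `quant_weight_ue8m0`, and `_get_q_b_proj_quant_config` swaps in a matching `Fp8Config`. A usage sketch, assuming the flag is read through `get_bool_env_var` and therefore must be set before the model is constructed:

    import os

    # Hypothetical setup: run attention GEMMs in FP8 when serving an NVFP4 DeepSeek checkpoint.
    os.environ["SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"] = "1"
    # load_weights will then produce, per layer:
    #   model.layers.{i}.self_attn.q_b_proj.weight           -> FP8 tensor
    #   model.layers.{i}.self_attn.q_b_proj.weight_scale_inv -> UE8M0 block scales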
3014
3556
  def get_embed_and_head(self):
3015
3557
  return self.model.embed_tokens.weight, self.lm_head.weight
3016
3558
 
@@ -3031,8 +3573,24 @@ class DeepseekV2ForCausalLM(nn.Module):
3031
3573
  )
3032
3574
 
3033
3575
 
3576
+ AttentionBackendRegistry.register("ascend", handle_attention_ascend)
3577
+ AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
3578
+ AttentionBackendRegistry.register("fa3", handle_attention_fa3)
3579
+ AttentionBackendRegistry.register("flashmla", handle_attention_flashmla)
3580
+ AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla)
3581
+ AttentionBackendRegistry.register("fa4", handle_attention_fa4)
3582
+ AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla)
3583
+ AttentionBackendRegistry.register("aiter", handle_attention_aiter)
3584
+ AttentionBackendRegistry.register("nsa", handle_attention_nsa)
3585
+ AttentionBackendRegistry.register("triton", handle_attention_triton)
3586
+
3587
+
3034
3588
  class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
3035
3589
  pass
3036
3590
 
3037
3591
 
3038
- EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]
3592
+ class DeepseekV32ForCausalLM(DeepseekV2ForCausalLM):
3593
+ pass
3594
+
3595
+
3596
+ EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM, DeepseekV32ForCausalLM]