sglang 0.5.2rc1__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (265)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/lang/interpreter.py +1 -1
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/device_config.py +3 -1
  6. sglang/srt/configs/dots_vlm.py +139 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/load_config.py +1 -0
  9. sglang/srt/configs/model_config.py +50 -6
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +8 -1
  12. sglang/srt/connector/remote_instance.py +82 -0
  13. sglang/srt/constrained/base_grammar_backend.py +48 -12
  14. sglang/srt/constrained/llguidance_backend.py +0 -1
  15. sglang/srt/constrained/outlines_backend.py +0 -1
  16. sglang/srt/constrained/xgrammar_backend.py +28 -9
  17. sglang/srt/custom_op.py +11 -1
  18. sglang/srt/debug_utils/dump_comparator.py +81 -44
  19. sglang/srt/debug_utils/dump_loader.py +97 -0
  20. sglang/srt/debug_utils/dumper.py +11 -3
  21. sglang/srt/debug_utils/text_comparator.py +73 -11
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +21 -10
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  26. sglang/srt/disaggregation/fake/conn.py +1 -1
  27. sglang/srt/disaggregation/mini_lb.py +6 -445
  28. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  29. sglang/srt/disaggregation/nixl/conn.py +180 -16
  30. sglang/srt/disaggregation/prefill.py +5 -3
  31. sglang/srt/disaggregation/utils.py +5 -50
  32. sglang/srt/distributed/parallel_state.py +67 -43
  33. sglang/srt/entrypoints/engine.py +38 -17
  34. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  35. sglang/srt/entrypoints/grpc_server.py +680 -0
  36. sglang/srt/entrypoints/http_server.py +88 -53
  37. sglang/srt/entrypoints/openai/protocol.py +7 -4
  38. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  39. sglang/srt/entrypoints/openai/serving_chat.py +39 -19
  40. sglang/srt/entrypoints/openai/serving_completions.py +15 -4
  41. sglang/srt/entrypoints/openai/serving_embedding.py +9 -4
  42. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  43. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  44. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  45. sglang/srt/eplb/eplb_manager.py +2 -2
  46. sglang/srt/eplb/expert_distribution.py +26 -13
  47. sglang/srt/eplb/expert_location.py +8 -3
  48. sglang/srt/eplb/expert_location_updater.py +1 -1
  49. sglang/srt/function_call/base_format_detector.py +3 -6
  50. sglang/srt/function_call/ebnf_composer.py +11 -9
  51. sglang/srt/function_call/function_call_parser.py +6 -0
  52. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  53. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  54. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  55. sglang/srt/grpc/__init__.py +1 -0
  56. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  57. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  58. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  59. sglang/srt/hf_transformers_utils.py +4 -0
  60. sglang/srt/layers/activation.py +142 -9
  61. sglang/srt/layers/attention/aiter_backend.py +93 -68
  62. sglang/srt/layers/attention/ascend_backend.py +11 -4
  63. sglang/srt/layers/attention/fla/chunk.py +242 -0
  64. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  65. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  66. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  67. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  68. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  69. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  70. sglang/srt/layers/attention/fla/index.py +37 -0
  71. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  72. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  73. sglang/srt/layers/attention/fla/op.py +66 -0
  74. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  75. sglang/srt/layers/attention/fla/utils.py +331 -0
  76. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  77. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  78. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  79. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  80. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  81. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  82. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  83. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  84. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  85. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  86. sglang/srt/layers/attention/triton_backend.py +18 -1
  87. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  88. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  89. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  90. sglang/srt/layers/communicator.py +45 -7
  91. sglang/srt/layers/dp_attention.py +30 -1
  92. sglang/srt/layers/layernorm.py +32 -15
  93. sglang/srt/layers/linear.py +34 -3
  94. sglang/srt/layers/logits_processor.py +29 -10
  95. sglang/srt/layers/moe/__init__.py +2 -1
  96. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  97. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  98. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  99. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  100. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  107. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  108. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  113. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  114. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  115. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  116. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  117. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  118. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  119. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  120. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  121. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  122. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  123. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  124. sglang/srt/layers/moe/topk.py +30 -9
  125. sglang/srt/layers/moe/utils.py +12 -7
  126. sglang/srt/layers/quantization/awq.py +19 -7
  127. sglang/srt/layers/quantization/base_config.py +11 -6
  128. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  129. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  130. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  131. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  132. sglang/srt/layers/quantization/fp8.py +76 -47
  133. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  134. sglang/srt/layers/quantization/gptq.py +25 -17
  135. sglang/srt/layers/quantization/modelopt_quant.py +182 -49
  136. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  137. sglang/srt/layers/quantization/mxfp4.py +68 -41
  138. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  139. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  140. sglang/srt/layers/quantization/quark/utils.py +97 -0
  141. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  142. sglang/srt/layers/quantization/unquant.py +135 -47
  143. sglang/srt/layers/quantization/w4afp8.py +30 -17
  144. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  145. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  146. sglang/srt/layers/rocm_linear_utils.py +44 -0
  147. sglang/srt/layers/rotary_embedding.py +0 -18
  148. sglang/srt/layers/sampler.py +162 -18
  149. sglang/srt/lora/backend/base_backend.py +50 -8
  150. sglang/srt/lora/backend/triton_backend.py +90 -2
  151. sglang/srt/lora/layers.py +32 -0
  152. sglang/srt/lora/lora.py +4 -1
  153. sglang/srt/lora/lora_manager.py +35 -112
  154. sglang/srt/lora/mem_pool.py +24 -10
  155. sglang/srt/lora/utils.py +18 -9
  156. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  157. sglang/srt/managers/cache_controller.py +200 -199
  158. sglang/srt/managers/data_parallel_controller.py +105 -35
  159. sglang/srt/managers/detokenizer_manager.py +8 -4
  160. sglang/srt/managers/disagg_service.py +46 -0
  161. sglang/srt/managers/io_struct.py +199 -12
  162. sglang/srt/managers/mm_utils.py +1 -0
  163. sglang/srt/managers/multi_tokenizer_mixin.py +351 -397
  164. sglang/srt/managers/schedule_batch.py +77 -56
  165. sglang/srt/managers/schedule_policy.py +4 -3
  166. sglang/srt/managers/scheduler.py +191 -139
  167. sglang/srt/managers/scheduler_metrics_mixin.py +116 -9
  168. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  169. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  170. sglang/srt/managers/template_manager.py +3 -3
  171. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  172. sglang/srt/managers/tokenizer_manager.py +260 -519
  173. sglang/srt/managers/tp_worker.py +53 -4
  174. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  175. sglang/srt/mem_cache/allocator.py +1 -1
  176. sglang/srt/mem_cache/hicache_storage.py +18 -33
  177. sglang/srt/mem_cache/hiradix_cache.py +108 -48
  178. sglang/srt/mem_cache/memory_pool.py +347 -48
  179. sglang/srt/mem_cache/memory_pool_host.py +121 -57
  180. sglang/srt/mem_cache/radix_cache.py +0 -2
  181. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  182. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  183. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +95 -5
  184. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  185. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  186. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +81 -20
  187. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  188. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  189. sglang/srt/metrics/collector.py +502 -77
  190. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  191. sglang/srt/metrics/utils.py +48 -0
  192. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  193. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  194. sglang/srt/model_executor/forward_batch_info.py +75 -19
  195. sglang/srt/model_executor/model_runner.py +357 -30
  196. sglang/srt/model_loader/__init__.py +9 -3
  197. sglang/srt/model_loader/loader.py +128 -4
  198. sglang/srt/model_loader/weight_utils.py +2 -1
  199. sglang/srt/models/apertus.py +686 -0
  200. sglang/srt/models/bailing_moe.py +798 -218
  201. sglang/srt/models/bailing_moe_nextn.py +168 -0
  202. sglang/srt/models/deepseek_v2.py +346 -48
  203. sglang/srt/models/dots_vlm.py +174 -0
  204. sglang/srt/models/dots_vlm_vit.py +337 -0
  205. sglang/srt/models/ernie4.py +1 -1
  206. sglang/srt/models/gemma3n_mm.py +1 -1
  207. sglang/srt/models/glm4_moe.py +11 -2
  208. sglang/srt/models/glm4v.py +4 -2
  209. sglang/srt/models/glm4v_moe.py +3 -0
  210. sglang/srt/models/gpt_oss.py +1 -1
  211. sglang/srt/models/internvl.py +28 -0
  212. sglang/srt/models/llama4.py +9 -0
  213. sglang/srt/models/llama_eagle3.py +13 -0
  214. sglang/srt/models/longcat_flash.py +2 -2
  215. sglang/srt/models/minicpmv.py +165 -3
  216. sglang/srt/models/mllama4.py +25 -0
  217. sglang/srt/models/opt.py +637 -0
  218. sglang/srt/models/qwen2.py +7 -0
  219. sglang/srt/models/qwen2_5_vl.py +27 -3
  220. sglang/srt/models/qwen2_moe.py +60 -13
  221. sglang/srt/models/qwen3.py +8 -2
  222. sglang/srt/models/qwen3_moe.py +40 -9
  223. sglang/srt/models/qwen3_next.py +1042 -0
  224. sglang/srt/models/qwen3_next_mtp.py +112 -0
  225. sglang/srt/models/step3_vl.py +1 -1
  226. sglang/srt/models/torch_native_llama.py +1 -1
  227. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  228. sglang/srt/multimodal/processors/glm4v.py +9 -9
  229. sglang/srt/multimodal/processors/internvl.py +141 -129
  230. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  231. sglang/srt/offloader.py +27 -3
  232. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  233. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  234. sglang/srt/sampling/sampling_batch_info.py +18 -15
  235. sglang/srt/server_args.py +355 -37
  236. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  237. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  238. sglang/srt/speculative/eagle_utils.py +0 -2
  239. sglang/srt/speculative/eagle_worker.py +197 -112
  240. sglang/srt/speculative/spec_info.py +5 -0
  241. sglang/srt/speculative/standalone_worker.py +109 -0
  242. sglang/srt/tracing/trace.py +552 -0
  243. sglang/srt/utils.py +46 -3
  244. sglang/srt/weight_sync/utils.py +1 -1
  245. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  246. sglang/test/few_shot_gsm8k.py +1 -0
  247. sglang/test/runners.py +4 -0
  248. sglang/test/test_cutlass_moe.py +24 -6
  249. sglang/test/test_disaggregation_utils.py +66 -0
  250. sglang/test/test_fp4_moe.py +370 -1
  251. sglang/test/test_utils.py +28 -1
  252. sglang/utils.py +12 -0
  253. sglang/version.py +1 -1
  254. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  255. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +263 -200
  256. sglang/srt/disaggregation/launch_lb.py +0 -118
  257. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  258. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  259. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  260. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  261. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  262. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  263. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  264. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  265. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
@@ -65,10 +65,11 @@ from sglang.srt.layers.moe import (
     get_deepep_mode,
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
+    should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
@@ -112,6 +113,7 @@ from sglang.srt.utils import (
     is_cpu,
     is_cuda,
     is_flashinfer_available,
+    is_gfx95_supported,
     is_hip,
     is_non_idle_and_non_empty,
     is_npu,
@@ -129,11 +131,28 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _device_sm = get_device_sm()
+_is_gfx95_supported = is_gfx95_supported()
+
+_use_aiter_gfx95 = _use_aiter and _is_gfx95_supported
+
+if _use_aiter_gfx95:
+    from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights
+    from sglang.srt.layers.quantization.rocm_mxfp4_utils import (
+        batched_gemm_afp4wfp4_pre_quant,
+        fused_flatten_mxfp4_quant,
+        fused_rms_mxfp4_quant,
+    )
+    from sglang.srt.layers.rocm_linear_utils import (
+        aiter_dsv3_router_gemm,
+        fused_qk_rope_cat,
+        get_dsv3_gemm_output_zero_allocator_size,
+    )
 
 if _is_cuda:
     from sgl_kernel import (
         awq_dequantize,
         bmm_fp8,
+        concat_mla_k,
         dsv3_fused_a_gemm,
         dsv3_router_gemm,
         merge_state_v2,
@@ -224,10 +243,21 @@ class DeepseekV2MLP(nn.Module):
         forward_batch=None,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x
 
+        if (
+            gemm_output_zero_allocator is not None
+            and x.shape[0] <= 256
+            and self.gate_up_proj.weight.dtype == torch.uint8
+        ):
+            y = gemm_output_zero_allocator.allocate(
+                x.shape[0] * self.gate_up_proj.output_size_per_partition
+            ).view(x.shape[0], self.gate_up_proj.output_size_per_partition)
+            x = (x, None, y)
+
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(
@@ -240,6 +270,7 @@ class MoEGate(nn.Module):
     def __init__(
         self,
         config,
+        quant_config,
         prefix: str = "",
         is_nextn: bool = False,
     ):
@@ -249,15 +280,22 @@ class MoEGate(nn.Module):
             torch.empty((config.n_routed_experts, config.hidden_size))
         )
         if config.topk_method == "noaux_tc":
+            correction_bias_dtype = (
+                torch.bfloat16
+                if quant_config is not None
+                and quant_config.get_name() == "modelopt_fp4"
+                and should_use_flashinfer_trtllm_moe()
+                else torch.float32
+            )
             self.e_score_correction_bias = nn.Parameter(
-                torch.empty((config.n_routed_experts), dtype=torch.float32)
+                torch.empty((config.n_routed_experts), dtype=correction_bias_dtype)
             )
         else:
             self.e_score_correction_bias = None
         if _is_cpu and _is_cpu_amx_available:
             self.quant_method = PackWeightMethod(weight_names=["weight"])
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None):
         if use_intel_amx_backend(self):
             return torch.ops.sgl_kernel.weight_packed_linear(
                 hidden_states,
@@ -275,7 +313,13 @@ class MoEGate(nn.Module):
             and _device_sm >= 90
         ):
             # router gemm output float32
-            logits = dsv3_router_gemm(hidden_states, self.weight)
+            logits = dsv3_router_gemm(
+                hidden_states, self.weight, out_dtype=torch.float32
+            )
+        elif _use_aiter_gfx95 and hidden_states.shape[0] <= 256:
+            logits = aiter_dsv3_router_gemm(
+                hidden_states, self.weight, gemm_output_zero_allocator
+            )
         else:
             logits = F.linear(hidden_states, self.weight, None)
 
@@ -319,7 +363,10 @@ class DeepseekV2MoE(nn.Module):
         )
 
         self.gate = MoEGate(
-            config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("gate", prefix),
+            is_nextn=is_nextn,
         )
 
         self.experts = get_moe_impl_class(quant_config)(
@@ -344,9 +391,12 @@ class DeepseekV2MoE(nn.Module):
             num_fused_shared_experts=self.num_fused_shared_experts,
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
+            quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
             apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
-            force_topk=quant_config is None,
+            # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
+            # and requires the output format to be standard. We use quant_config to determine the output format.
+            output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
         )
 
         self.shared_experts_is_int8 = False
@@ -439,6 +489,7 @@ class DeepseekV2MoE(nn.Module):
         forward_batch: Optional[ForwardBatch] = None,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024
@@ -452,12 +503,14 @@ class DeepseekV2MoE(nn.Module):
                     hidden_states,
                     should_allreduce_fusion,
                     use_reduce_scatter,
+                    gemm_output_zero_allocator,
                 )
             else:
                 return self.forward_normal(
                     hidden_states,
                     should_allreduce_fusion,
                     use_reduce_scatter,
+                    gemm_output_zero_allocator,
                 )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
@@ -467,15 +520,18 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
 
         current_stream = torch.cuda.current_stream()
         self.alt_stream.wait_stream(current_stream)
-        shared_output = self._forward_shared_experts(hidden_states)
+        shared_output = self._forward_shared_experts(
+            hidden_states, gemm_output_zero_allocator
+        )
 
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(hidden_states)
+            router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
             final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:
@@ -502,6 +558,7 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
@@ -509,9 +566,11 @@ class DeepseekV2MoE(nn.Module):
             return self.forward_cpu(hidden_states, should_allreduce_fusion)
 
         if hidden_states.shape[0] > 0:
-            shared_output = self._forward_shared_experts(hidden_states)
+            shared_output = self._forward_shared_experts(
+                hidden_states, gemm_output_zero_allocator
+            )
             # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(hidden_states)
+            router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
         else:
             shared_output = None
@@ -624,16 +683,24 @@ class DeepseekV2MoE(nn.Module):
 
         if shared_output is not None:
             x = shared_output
-            x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
+            if self.experts.should_fuse_routed_scaling_factor_in_topk():
+                x.add_(final_hidden_states)
+            else:
+                x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
             final_hidden_states = x
         else:
-            final_hidden_states *= self.routed_scaling_factor
+            if not self.experts.should_fuse_routed_scaling_factor_in_topk():
+                final_hidden_states *= self.routed_scaling_factor
 
         return final_hidden_states
 
-    def _forward_shared_experts(self, hidden_states):
+    def _forward_shared_experts(
+        self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None
+    ):
         if self.num_fused_shared_experts == 0:
-            return self.shared_experts(hidden_states)
+            return self.shared_experts(
+                hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
+            )
         else:
             return None
 
@@ -992,6 +1059,15 @@ class DeepseekV2AttentionMLA(nn.Module):
         # Determine attention backend used by current forward batch
         if forward_batch.forward_mode.is_decode_or_idle():
             attention_backend = global_server_args_dict["decode_attention_backend"]
+        elif (
+            forward_batch.forward_mode.is_target_verify()
+            or forward_batch.forward_mode.is_draft_extend()
+        ):
+            # Use the specified backend for speculative operations (both verify and draft extend)
+            if global_server_args_dict["speculative_attention_mode"] == "decode":
+                attention_backend = global_server_args_dict["decode_attention_backend"]
+            else:  # default to prefill
+                attention_backend = global_server_args_dict["prefill_attention_backend"]
         else:
             attention_backend = global_server_args_dict["prefill_attention_backend"]
         self.current_attention_backend = attention_backend
@@ -1009,7 +1085,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             attention_backend == "flashinfer"
             or attention_backend == "fa3"
             or attention_backend == "flashmla"
-            or attention_backend == "trtllm_mla"
             or attention_backend == "cutlass_mla"
         ):
             # Use MHA with chunked KV cache when prefilling on long sequences.
@@ -1022,6 +1097,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             disable_ragged = (
                 attention_backend == "flashinfer" or attention_backend == "flashmla"
             ) and self.flashinfer_mla_disable_ragged
+
+            original_mode = getattr(forward_batch, "_original_forward_mode", None)
             if (
                 not disable_ragged
                 and forward_batch.forward_mode.is_extend()
@@ -1034,6 +1111,40 @@ class DeepseekV2AttentionMLA(nn.Module):
                    )
                    or sum_extend_prefix_lens == 0
                )
+                # TODO(shuw@nvidia.com) Flashinfer cutlass and trtllm_mla backend have accuracy issue on blackwell for
+                # dp case. Redirect to mla kernel as a workaround.
+                # Tracked by https://github.com/sgl-project/sglang/issues/9806.
+                and not (
+                    original_mode is not None
+                    and original_mode.is_decode()
+                    and is_sm100_supported()
+                    and self.current_attention_backend in ("cutlass_mla", "flashinfer")
+                )
+            ):
+                return AttnForwardMethod.MHA_CHUNKED_KV
+            else:
+                return _dispatch_mla_subtype()
+        elif attention_backend == "trtllm_mla":
+            original_mode = getattr(forward_batch, "_original_forward_mode", None)
+            if (
+                original_mode is not None
+                and original_mode.is_decode()
+                and is_sm100_supported()
+            ):
+                return _dispatch_mla_subtype()
+
+            sum_extend_prefix_lens = (
+                sum(forward_batch.extend_prefix_lens_cpu)
+                if forward_batch.extend_prefix_lens_cpu is not None
+                else 0
+            )
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+                and (
+                    not self.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
+                )
             ):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
@@ -1044,7 +1155,13 @@ class DeepseekV2AttentionMLA(nn.Module):
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
             ):
-                return AttnForwardMethod.MHA
+                if is_dp_attention_enabled():
+                    if sum(forward_batch.extend_prefix_lens_cpu) == 0:
+                        return AttnForwardMethod.MHA
+                    else:
+                        return AttnForwardMethod.MLA
+                else:
+                    return AttnForwardMethod.MHA
             else:
                 return AttnForwardMethod.MLA
         else:
@@ -1097,11 +1214,19 @@ class DeepseekV2AttentionMLA(nn.Module):
         if self.attn_mha.kv_b_proj is None:
             self.attn_mha.kv_b_proj = self.kv_b_proj
 
-        if hidden_states.shape[0] == 0:
-            assert (
-                not self.o_proj.reduce_results
-            ), "short-circuiting allreduce will lead to hangs"
-            return hidden_states, None, forward_batch, None
+        # when hidden_states is a tuple of tensors, the tuple will include quantized weight and scale tensor
+        if isinstance(hidden_states, tuple):
+            if hidden_states[0].shape[0] == 0:
+                assert (
+                    not self.o_proj.reduce_results
+                ), "short-circuiting allreduce will lead to hangs"
+                return hidden_states[0]
+        else:
+            if hidden_states.shape[0] == 0:
+                assert (
+                    not self.o_proj.reduce_results
+                ), "short-circuiting allreduce will lead to hangs"
+                return hidden_states, None, forward_batch, None
 
         attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
 
@@ -1180,8 +1305,18 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
         q[..., self.qk_nope_head_dim :] = q_pe
         k = torch.empty_like(q)
-        k[..., : self.qk_nope_head_dim] = k_nope
-        k[..., self.qk_nope_head_dim :] = k_pe
+
+        # Temporary for DeepSeek V3/R1 only, but can generalize if needed
+        if (
+            _is_cuda
+            and (self.num_local_heads == 128)
+            and (self.qk_nope_head_dim == 128)
+            and (self.qk_rope_head_dim == 64)
+        ):
+            concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
+        else:
+            k[..., : self.qk_nope_head_dim] = k_nope
+            k[..., self.qk_nope_head_dim :] = k_pe
 
         if not _is_npu:
             latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
@@ -1225,7 +1360,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 
         if self.q_lora_rank is not None:
-            if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
+            if (
+                (not isinstance(hidden_states, tuple))
+                and hidden_states.shape[0] <= 16
+                and self.use_min_latency_fused_a_gemm
+            ):
                 fused_qkv_a_proj_out = dsv3_fused_a_gemm(
                     hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
                 )
@@ -1245,8 +1384,18 @@ class DeepseekV2AttentionMLA(nn.Module):
                 k_nope = self.kv_a_layernorm(k_nope)
                 current_stream.wait_stream(self.alt_stream)
             else:
-                q = self.q_a_layernorm(q)
-                k_nope = self.kv_a_layernorm(k_nope)
+                if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
+                    q, k_nope = fused_rms_mxfp4_quant(
+                        q,
+                        self.q_a_layernorm.weight,
+                        self.q_a_layernorm.variance_epsilon,
+                        k_nope,
+                        self.kv_a_layernorm.weight,
+                        self.kv_a_layernorm.variance_epsilon,
+                    )
+                else:
+                    q = self.q_a_layernorm(q)
+                    k_nope = self.kv_a_layernorm(k_nope)
 
             k_nope = k_nope.unsqueeze(1)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
@@ -1278,10 +1427,27 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = q_nope_out[:, :expected_m, :]
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
-            q_nope_out = torch.bmm(
-                q_nope.to(torch.bfloat16).transpose(0, 1),
-                self.w_kc.to(torch.bfloat16) * self.w_scale,
-            )
+            if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
+                x = q_nope.transpose(0, 1)
+                q_nope_out = torch.empty(
+                    x.shape[0],
+                    x.shape[1],
+                    self.w_kc.shape[2],
+                    device=x.device,
+                    dtype=torch.bfloat16,
+                )
+                batched_gemm_afp4wfp4_pre_quant(
+                    x,
+                    self.w_kc.transpose(-2, -1),
+                    self.w_scale_k.transpose(-2, -1),
+                    torch.bfloat16,
+                    q_nope_out,
+                )
+            else:
+                q_nope_out = torch.bmm(
+                    q_nope.to(torch.bfloat16).transpose(0, 1),
+                    self.w_kc.to(torch.bfloat16) * self.w_scale,
+                )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                 q_nope.transpose(0, 1),
@@ -1295,13 +1461,15 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         q_nope_out = q_nope_out.transpose(0, 1)
 
-        if not self._fuse_rope_for_trtllm_mla(forward_batch):
+        if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
+            not _use_aiter or not _is_gfx95_supported
+        ):
             q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
-        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
 
     def forward_absorb_core(
-        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
     ):
         if (
             self.current_attention_backend == "fa3"
@@ -1326,8 +1494,23 @@ class DeepseekV2AttentionMLA(nn.Module):
                 **extra_args,
             )
         else:
-            q = torch.cat([q_nope_out, q_pe], dim=-1)
-            k = torch.cat([k_nope, k_pe], dim=-1)
+            if _use_aiter_gfx95:
+                cos = self.rotary_emb.cos_cache
+                sin = self.rotary_emb.sin_cache
+                q, k = fused_qk_rope_cat(
+                    q_nope_out,
+                    q_pe,
+                    k_nope,
+                    k_pe,
+                    positions,
+                    cos,
+                    sin,
+                    self.rotary_emb.is_neox_style,
+                )
+            else:
+                q = torch.cat([q_nope_out, q_pe], dim=-1)
+                k = torch.cat([k_nope, k_pe], dim=-1)
+
             attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
 
@@ -1352,11 +1535,34 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
-            attn_bmm_output = torch.bmm(
-                attn_output.to(torch.bfloat16).transpose(0, 1),
-                self.w_vc.to(torch.bfloat16) * self.w_scale,
-            )
-            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
+            if _use_aiter_gfx95 and self.w_vc.dtype == torch.uint8:
+                x = attn_output.transpose(0, 1)
+                attn_bmm_output = torch.empty(
+                    x.shape[0],
+                    x.shape[1],
+                    self.w_vc.shape[2],
+                    device=x.device,
+                    dtype=torch.bfloat16,
+                )
+                batched_gemm_afp4wfp4_pre_quant(
+                    x,
+                    self.w_vc.transpose(-2, -1),
+                    self.w_scale_v.transpose(-2, -1),
+                    torch.bfloat16,
+                    attn_bmm_output,
+                )
+            else:
+                attn_bmm_output = torch.bmm(
+                    attn_output.to(torch.bfloat16).transpose(0, 1),
+                    self.w_vc.to(torch.bfloat16) * self.w_scale,
+                )
+
+            if self.o_proj.weight.dtype == torch.uint8:
+                attn_bmm_output = attn_bmm_output.transpose(0, 1)
+                attn_bmm_output = fused_flatten_mxfp4_quant(attn_bmm_output)
+            else:
+                attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
+
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),
@@ -1678,9 +1884,11 @@ class DeepseekV2AttentionMLA(nn.Module):
             latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
                 self.attn_mha.layer_id
             )
-            latent_cache = latent_cache_buf[
-                forward_batch.prefix_chunk_kv_indices[i]
-            ].contiguous()
+            latent_cache = (
+                latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
+                .contiguous()
+                .to(q.dtype)
+            )
 
             kv_a_normed, k_pe = latent_cache.split(
                 [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
@@ -1864,10 +2072,24 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
 
+        quant_format = (
+            "mxfp4"
+            if _is_gfx95_supported
+            and getattr(self.self_attn, "fused_qkv_a_proj_with_mqa", None) is not None
+            and getattr(self.self_attn.fused_qkv_a_proj_with_mqa, "weight", None)
+            is not None
+            and self.self_attn.fused_qkv_a_proj_with_mqa.weight.dtype == torch.uint8
+            else ""
+        )
+
         hidden_states, residual = self.layer_communicator.prepare_attn(
-            hidden_states, residual, forward_batch
+            hidden_states,
+            residual,
+            forward_batch,
+            quant_format,
         )
 
         hidden_states = self.self_attn(
@@ -1891,8 +2113,16 @@ class DeepseekV2DecoderLayer(nn.Module):
         use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
             forward_batch
         )
+
+        if isinstance(self.mlp, DeepseekV2MLP):
+            gemm_output_zero_allocator = None
+
         hidden_states = self.mlp(
-            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
+            hidden_states,
+            forward_batch,
+            should_allreduce_fusion,
+            use_reduce_scatter,
+            gemm_output_zero_allocator,
         )
 
         if should_allreduce_fusion:
@@ -2023,8 +2253,15 @@ class DeepseekV2Model(nn.Module):
                 [
                     "w13_weight",
                     "w2_weight",
-                    "w13_blockscale_swizzled",
-                    "w2_blockscale_swizzled",
+                    # only for nvfp4
+                    *(
+                        [
+                            "w13_blockscale_swizzled",
+                            "w2_blockscale_swizzled",
+                        ]
+                        if hasattr(module, "w13_blockscale_swizzled")
+                        else []
+                    ),
                 ]
                 if isinstance(module, FusedMoE)
                 else []
@@ -2036,6 +2273,37 @@ class DeepseekV2Model(nn.Module):
         else:
             self.norm = PPMissingLayer(return_tuple=True)
 
+        self.gemm_output_zero_allocator_size = 0
+        if (
+            _use_aiter_gfx95
+            and config.n_routed_experts == 256
+            and self.embed_tokens.embedding_dim == 7168
+        ):
+            num_moe_layers = sum(
+                [
+                    1
+                    for i in range(len(self.layers))
+                    if isinstance(self.layers[i].mlp, DeepseekV2MoE)
+                ]
+            )
+
+            allocate_size = 0
+            for i in range(len(self.layers)):
+                if isinstance(self.layers[i].mlp, DeepseekV2MoE):
+                    allocate_size = self.layers[
+                        i
+                    ].mlp.shared_experts.gate_up_proj.output_size_per_partition
+                    break
+
+            self.gemm_output_zero_allocator_size = (
+                get_dsv3_gemm_output_zero_allocator_size(
+                    config.n_routed_experts,
+                    num_moe_layers,
+                    allocate_size,
+                    self.embed_tokens.embedding_dim,
+                )
+            )
+
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
 
@@ -2055,6 +2323,21 @@ class DeepseekV2Model(nn.Module):
             device=device,
         )
 
+        has_gemm_output_zero_allocator = hasattr(
+            self, "gemm_output_zero_allocator_size"
+        )
+
+        gemm_output_zero_allocator = (
+            BumpAllocator(
+                buffer_size=self.gemm_output_zero_allocator_size,
+                dtype=torch.float32,
+                device=device,
+            )
+            if has_gemm_output_zero_allocator
+            and self.gemm_output_zero_allocator_size > 0
+            else None
+        )
+
         if self.pp_group.is_first_rank:
             if input_embeds is None:
                 hidden_states = self.embed_tokens(input_ids)
@@ -2081,7 +2364,12 @@ class DeepseekV2Model(nn.Module):
             with get_global_expert_distribution_recorder().with_current_layer(i):
                 layer = self.layers[i]
                 hidden_states, residual = layer(
-                    positions, hidden_states, forward_batch, residual, zero_allocator
+                    positions,
+                    hidden_states,
+                    forward_batch,
+                    residual,
+                    zero_allocator,
+                    gemm_output_zero_allocator,
                 )
 
         if normal_end_layer != self.end_layer:
@@ -2354,6 +2642,16 @@ class DeepseekV2ForCausalLM(nn.Module):
                 w_kc, w_vc = w.unflatten(
                     0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                 ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+
+                if (
+                    _use_aiter_gfx95
+                    and self.quant_config is not None
+                    and self.quant_config.get_name() == "quark"
+                ):
+                    w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = (
+                        quark_post_load_weights(self_attn, w, "mxfp4")
+                    )
+
                 if not use_deep_gemm_bmm:
                     self_attn.w_kc = bind_or_assign(
                         self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)