sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff compares two publicly released versions of the package and reflects the changes between them as they appear in their public registry. It is provided for informational purposes only.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
@@ -67,7 +67,10 @@ from sglang.srt.layers.moe import (
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
-from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.moe.fused_moe_triton.layer import (
+    FusedMoE,
+    _is_fp4_quantization_enabled,
+)
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -112,8 +115,10 @@ from sglang.srt.utils import (
     is_cpu,
     is_cuda,
     is_flashinfer_available,
+    is_gfx95_supported,
     is_hip,
     is_non_idle_and_non_empty,
+    is_npu,
     is_sm100_supported,
     log_info_on_rank0,
     make_layers,
@@ -122,11 +127,28 @@ from sglang.srt.utils import (
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _device_sm = get_device_sm()
+_is_gfx95_supported = is_gfx95_supported()
+
+_use_aiter_gfx95 = _use_aiter and _is_gfx95_supported
+
+if _use_aiter_gfx95:
+    from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights
+    from sglang.srt.layers.quantization.rocm_mxfp4_utils import (
+        batched_gemm_afp4wfp4_pre_quant,
+        fused_flatten_mxfp4_quant,
+        fused_rms_mxfp4_quant,
+    )
+    from sglang.srt.layers.rocm_linear_utils import (
+        aiter_dsv3_router_gemm,
+        fused_qk_rope_cat,
+        get_dsv3_gemm_output_zero_allocator_size,
+    )
 
 if _is_cuda:
     from sgl_kernel import (
@@ -222,10 +244,21 @@ class DeepseekV2MLP(nn.Module):
         forward_batch=None,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x
 
+        if (
+            gemm_output_zero_allocator is not None
+            and x.shape[0] <= 256
+            and self.gate_up_proj.weight.dtype == torch.uint8
+        ):
+            y = gemm_output_zero_allocator.allocate(
+                x.shape[0] * self.gate_up_proj.output_size_per_partition
+            ).view(x.shape[0], self.gate_up_proj.output_size_per_partition)
+            x = (x, None, y)
+
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(
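Note: the hunk above threads a new gemm_output_zero_allocator argument through DeepseekV2MLP.forward and carves a pre-zeroed GEMM output buffer out of it for the uint8 (MXFP4) weight path. The snippet below is a minimal, self-contained sketch of that bump-allocator pattern; SimpleBumpAllocator and its sizes are hypothetical stand-ins for illustration only, not sglang's actual BumpAllocator.

    import torch

    class SimpleBumpAllocator:
        # Hands out slices of one pre-zeroed buffer instead of allocating
        # per layer; illustrative stand-in only.
        def __init__(self, buffer_size, dtype=torch.float32, device="cpu"):
            self._buffer = torch.zeros(buffer_size, dtype=dtype, device=device)
            self._offset = 0

        def allocate(self, num_elements):
            start = self._offset
            self._offset += num_elements
            return self._buffer[start : self._offset]

    # Usage mirroring the hunk: reserve a (tokens, out_features) view for the GEMM output.
    allocator = SimpleBumpAllocator(buffer_size=256 * 1024)
    tokens, out_features = 8, 1024
    y = allocator.allocate(tokens * out_features).view(tokens, out_features)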
@@ -255,7 +288,7 @@ class MoEGate(nn.Module):
         if _is_cpu and _is_cpu_amx_available:
             self.quant_method = PackWeightMethod(weight_names=["weight"])
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None):
         if use_intel_amx_backend(self):
             return torch.ops.sgl_kernel.weight_packed_linear(
                 hidden_states,
@@ -273,7 +306,13 @@ class MoEGate(nn.Module):
             and _device_sm >= 90
         ):
             # router gemm output float32
-            logits = dsv3_router_gemm(hidden_states, self.weight)
+            logits = dsv3_router_gemm(
+                hidden_states, self.weight, out_dtype=torch.float32
+            )
+        elif _use_aiter_gfx95 and hidden_states.shape[0] <= 256:
+            logits = aiter_dsv3_router_gemm(
+                hidden_states, self.weight, gemm_output_zero_allocator
+            )
         else:
             logits = F.linear(hidden_states, self.weight, None)
 
@@ -334,6 +373,9 @@ class DeepseekV2MoE(nn.Module):
             prefix=add_prefix("experts", prefix),
         )
 
+        correction_bias = self.gate.e_score_correction_bias
+        if _is_fp4_quantization_enabled():
+            correction_bias = correction_bias.to(torch.bfloat16)
         self.topk = TopK(
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             renormalize=config.norm_topk_prob,
@@ -341,7 +383,7 @@ class DeepseekV2MoE(nn.Module):
             num_expert_group=config.n_group,
             num_fused_shared_experts=self.num_fused_shared_experts,
             topk_group=config.topk_group,
-            correction_bias=self.gate.e_score_correction_bias,
+            correction_bias=correction_bias,
             routed_scaling_factor=self.routed_scaling_factor,
             apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
             force_topk=quant_config is None,
@@ -437,6 +479,7 @@ class DeepseekV2MoE(nn.Module):
         forward_batch: Optional[ForwardBatch] = None,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024
@@ -450,12 +493,14 @@ class DeepseekV2MoE(nn.Module):
                     hidden_states,
                     should_allreduce_fusion,
                     use_reduce_scatter,
+                    gemm_output_zero_allocator,
                 )
             else:
                 return self.forward_normal(
                     hidden_states,
                     should_allreduce_fusion,
                     use_reduce_scatter,
+                    gemm_output_zero_allocator,
                 )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
@@ -465,15 +510,18 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
 
         current_stream = torch.cuda.current_stream()
         self.alt_stream.wait_stream(current_stream)
-        shared_output = self._forward_shared_experts(hidden_states)
+        shared_output = self._forward_shared_experts(
+            hidden_states, gemm_output_zero_allocator
+        )
 
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(hidden_states)
+            router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
             final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:
@@ -500,6 +548,7 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
@@ -507,9 +556,11 @@ class DeepseekV2MoE(nn.Module):
             return self.forward_cpu(hidden_states, should_allreduce_fusion)
 
         if hidden_states.shape[0] > 0:
-            shared_output = self._forward_shared_experts(hidden_states)
+            shared_output = self._forward_shared_experts(
+                hidden_states, gemm_output_zero_allocator
+            )
             # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(hidden_states)
+            router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
         else:
             shared_output = None
@@ -629,9 +680,13 @@ class DeepseekV2MoE(nn.Module):
 
         return final_hidden_states
 
-    def _forward_shared_experts(self, hidden_states):
+    def _forward_shared_experts(
+        self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None
+    ):
         if self.num_fused_shared_experts == 0:
-            return self.shared_experts(hidden_states)
+            return self.shared_experts(
+                hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
+            )
         else:
             return None
 
@@ -990,6 +1045,15 @@ class DeepseekV2AttentionMLA(nn.Module):
         # Determine attention backend used by current forward batch
         if forward_batch.forward_mode.is_decode_or_idle():
             attention_backend = global_server_args_dict["decode_attention_backend"]
+        elif (
+            forward_batch.forward_mode.is_target_verify()
+            or forward_batch.forward_mode.is_draft_extend()
+        ):
+            # Use the specified backend for speculative operations (both verify and draft extend)
+            if global_server_args_dict["speculative_attention_mode"] == "decode":
+                attention_backend = global_server_args_dict["decode_attention_backend"]
+            else:  # default to prefill
+                attention_backend = global_server_args_dict["prefill_attention_backend"]
         else:
             attention_backend = global_server_args_dict["prefill_attention_backend"]
         self.current_attention_backend = attention_backend
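The backend-dispatch change above can be read as the standalone sketch below; forward_mode and server_args stand in for the batch's forward mode object and global_server_args_dict from the hunk, and are assumptions for illustration only.

    def select_attention_backend(forward_mode, server_args):
        # Decode and idle batches always use the decode backend.
        if forward_mode.is_decode_or_idle():
            return server_args["decode_attention_backend"]
        # Speculative batches (target verify / draft extend) follow the
        # speculative_attention_mode setting; the default is the prefill backend.
        if forward_mode.is_target_verify() or forward_mode.is_draft_extend():
            if server_args["speculative_attention_mode"] == "decode":
                return server_args["decode_attention_backend"]
            return server_args["prefill_attention_backend"]
        # All other (extend/prefill) batches use the prefill backend.
        return server_args["prefill_attention_backend"]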
@@ -1007,7 +1071,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             attention_backend == "flashinfer"
             or attention_backend == "fa3"
             or attention_backend == "flashmla"
-            or attention_backend == "trtllm_mla"
             or attention_backend == "cutlass_mla"
         ):
             # Use MHA with chunked KV cache when prefilling on long sequences.
@@ -1036,13 +1099,28 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
                 return _dispatch_mla_subtype()
+        elif attention_backend == "trtllm_mla":
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                return AttnForwardMethod.MHA_CHUNKED_KV
+            else:
+                return _dispatch_mla_subtype()
         elif attention_backend == "aiter":
             if (
                 forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
             ):
-                return AttnForwardMethod.MHA
+                if is_dp_attention_enabled():
+                    if sum(forward_batch.extend_prefix_lens_cpu) == 0:
+                        return AttnForwardMethod.MHA
+                    else:
+                        return AttnForwardMethod.MLA
+                else:
+                    return AttnForwardMethod.MHA
             else:
                 return AttnForwardMethod.MLA
         else:
@@ -1095,11 +1173,19 @@ class DeepseekV2AttentionMLA(nn.Module):
         if self.attn_mha.kv_b_proj is None:
             self.attn_mha.kv_b_proj = self.kv_b_proj
 
-        if hidden_states.shape[0] == 0:
-            assert (
-                not self.o_proj.reduce_results
-            ), "short-circuiting allreduce will lead to hangs"
-            return hidden_states, None, forward_batch, None
+        # when hidden_states is a tuple of tensors, the tuple will include quantized weight and scale tensor
+        if isinstance(hidden_states, tuple):
+            if hidden_states[0].shape[0] == 0:
+                assert (
+                    not self.o_proj.reduce_results
+                ), "short-circuiting allreduce will lead to hangs"
+                return hidden_states[0]
+        else:
+            if hidden_states.shape[0] == 0:
+                assert (
+                    not self.o_proj.reduce_results
+                ), "short-circuiting allreduce will lead to hangs"
+                return hidden_states, None, forward_batch, None
 
         attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
 
@@ -1181,13 +1267,19 @@ class DeepseekV2AttentionMLA(nn.Module):
         k[..., : self.qk_nope_head_dim] = k_nope
         k[..., self.qk_nope_head_dim :] = k_pe
 
-        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
-        latent_cache[:, :, self.kv_lora_rank :] = k_pe
+        if not _is_npu:
+            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+            latent_cache[:, :, self.kv_lora_rank :] = k_pe
 
-        # Save latent cache
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
-        )
+            # Save latent cache
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+            )
+        else:
+            # To reduce a time-costing split operation
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            )
 
         return q, k, v, forward_batch
 
@@ -1217,7 +1309,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 
         if self.q_lora_rank is not None:
-            if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
+            if (
+                (not isinstance(hidden_states, tuple))
+                and hidden_states.shape[0] <= 16
+                and self.use_min_latency_fused_a_gemm
+            ):
                 fused_qkv_a_proj_out = dsv3_fused_a_gemm(
                     hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
                 )
@@ -1237,8 +1333,18 @@ class DeepseekV2AttentionMLA(nn.Module):
                     k_nope = self.kv_a_layernorm(k_nope)
                 current_stream.wait_stream(self.alt_stream)
             else:
-                q = self.q_a_layernorm(q)
-                k_nope = self.kv_a_layernorm(k_nope)
+                if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
+                    q, k_nope = fused_rms_mxfp4_quant(
+                        q,
+                        self.q_a_layernorm.weight,
+                        self.q_a_layernorm.variance_epsilon,
+                        k_nope,
+                        self.kv_a_layernorm.weight,
+                        self.kv_a_layernorm.variance_epsilon,
+                    )
+                else:
+                    q = self.q_a_layernorm(q)
+                    k_nope = self.kv_a_layernorm(k_nope)
 
             k_nope = k_nope.unsqueeze(1)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
@@ -1270,10 +1376,27 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = q_nope_out[:, :expected_m, :]
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
-            q_nope_out = torch.bmm(
-                q_nope.to(torch.bfloat16).transpose(0, 1),
-                self.w_kc.to(torch.bfloat16) * self.w_scale,
-            )
+            if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
+                x = q_nope.transpose(0, 1)
+                q_nope_out = torch.empty(
+                    x.shape[0],
+                    x.shape[1],
+                    self.w_kc.shape[2],
+                    device=x.device,
+                    dtype=torch.bfloat16,
+                )
+                batched_gemm_afp4wfp4_pre_quant(
+                    x,
+                    self.w_kc.transpose(-2, -1),
+                    self.w_scale_k.transpose(-2, -1),
+                    torch.bfloat16,
+                    q_nope_out,
+                )
+            else:
+                q_nope_out = torch.bmm(
+                    q_nope.to(torch.bfloat16).transpose(0, 1),
+                    self.w_kc.to(torch.bfloat16) * self.w_scale,
+                )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                 q_nope.transpose(0, 1),
@@ -1287,13 +1410,15 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         q_nope_out = q_nope_out.transpose(0, 1)
 
-        if not self._fuse_rope_for_trtllm_mla(forward_batch):
+        if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
+            not _use_aiter or not _is_gfx95_supported
+        ):
             q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
-        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
 
     def forward_absorb_core(
-        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
     ):
         if (
             self.current_attention_backend == "fa3"
@@ -1318,8 +1443,23 @@ class DeepseekV2AttentionMLA(nn.Module):
                 **extra_args,
             )
         else:
-            q = torch.cat([q_nope_out, q_pe], dim=-1)
-            k = torch.cat([k_nope, k_pe], dim=-1)
+            if _use_aiter_gfx95:
+                cos = self.rotary_emb.cos_cache
+                sin = self.rotary_emb.sin_cache
+                q, k = fused_qk_rope_cat(
+                    q_nope_out,
+                    q_pe,
+                    k_nope,
+                    k_pe,
+                    positions,
+                    cos,
+                    sin,
+                    self.rotary_emb.is_neox_style,
+                )
+            else:
+                q = torch.cat([q_nope_out, q_pe], dim=-1)
+                k = torch.cat([k_nope, k_pe], dim=-1)
+
             attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
 
@@ -1344,11 +1484,34 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
-            attn_bmm_output = torch.bmm(
-                attn_output.to(torch.bfloat16).transpose(0, 1),
-                self.w_vc.to(torch.bfloat16) * self.w_scale,
-            )
-            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
+            if _use_aiter_gfx95 and self.w_vc.dtype == torch.uint8:
+                x = attn_output.transpose(0, 1)
+                attn_bmm_output = torch.empty(
+                    x.shape[0],
+                    x.shape[1],
+                    self.w_vc.shape[2],
+                    device=x.device,
+                    dtype=torch.bfloat16,
+                )
+                batched_gemm_afp4wfp4_pre_quant(
+                    x,
+                    self.w_vc.transpose(-2, -1),
+                    self.w_scale_v.transpose(-2, -1),
+                    torch.bfloat16,
+                    attn_bmm_output,
+                )
+            else:
+                attn_bmm_output = torch.bmm(
+                    attn_output.to(torch.bfloat16).transpose(0, 1),
+                    self.w_vc.to(torch.bfloat16) * self.w_scale,
+                )
+
+            if self.o_proj.weight.dtype == torch.uint8:
+                attn_bmm_output = attn_bmm_output.transpose(0, 1)
+                attn_bmm_output = fused_flatten_mxfp4_quant(attn_bmm_output)
+            else:
+                attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
+
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),
@@ -1670,9 +1833,11 @@ class DeepseekV2AttentionMLA(nn.Module):
             latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
                 self.attn_mha.layer_id
             )
-            latent_cache = latent_cache_buf[
-                forward_batch.prefix_chunk_kv_indices[i]
-            ].contiguous()
+            latent_cache = (
+                latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
+                .contiguous()
+                .to(q.dtype)
+            )
 
             kv_a_normed, k_pe = latent_cache.split(
                 [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
@@ -1856,10 +2021,24 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
 
+        quant_format = (
+            "mxfp4"
+            if _is_gfx95_supported
+            and getattr(self.self_attn, "fused_qkv_a_proj_with_mqa", None) is not None
+            and getattr(self.self_attn.fused_qkv_a_proj_with_mqa, "weight", None)
+            is not None
+            and self.self_attn.fused_qkv_a_proj_with_mqa.weight.dtype == torch.uint8
+            else ""
+        )
+
         hidden_states, residual = self.layer_communicator.prepare_attn(
-            hidden_states, residual, forward_batch
+            hidden_states,
+            residual,
+            forward_batch,
+            quant_format,
         )
 
         hidden_states = self.self_attn(
@@ -1883,8 +2062,16 @@ class DeepseekV2DecoderLayer(nn.Module):
         use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
             forward_batch
         )
+
+        if isinstance(self.mlp, DeepseekV2MLP):
+            gemm_output_zero_allocator = None
+
         hidden_states = self.mlp(
-            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
+            hidden_states,
+            forward_batch,
+            should_allreduce_fusion,
+            use_reduce_scatter,
+            gemm_output_zero_allocator,
         )
 
         if should_allreduce_fusion:
@@ -2028,6 +2215,37 @@ class DeepseekV2Model(nn.Module):
         else:
             self.norm = PPMissingLayer(return_tuple=True)
 
+        self.gemm_output_zero_allocator_size = 0
+        if (
+            _use_aiter_gfx95
+            and config.n_routed_experts == 256
+            and self.embed_tokens.embedding_dim == 7168
+        ):
+            num_moe_layers = sum(
+                [
+                    1
+                    for i in range(len(self.layers))
+                    if isinstance(self.layers[i].mlp, DeepseekV2MoE)
+                ]
+            )
+
+            allocate_size = 0
+            for i in range(len(self.layers)):
+                if isinstance(self.layers[i].mlp, DeepseekV2MoE):
+                    allocate_size = self.layers[
+                        i
+                    ].mlp.shared_experts.gate_up_proj.output_size_per_partition
+                    break
+
+            self.gemm_output_zero_allocator_size = (
+                get_dsv3_gemm_output_zero_allocator_size(
+                    config.n_routed_experts,
+                    num_moe_layers,
+                    allocate_size,
+                    self.embed_tokens.embedding_dim,
+                )
+            )
+
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
 
@@ -2047,6 +2265,21 @@ class DeepseekV2Model(nn.Module):
             device=device,
         )
 
+        has_gemm_output_zero_allocator = hasattr(
+            self, "gemm_output_zero_allocator_size"
+        )
+
+        gemm_output_zero_allocator = (
+            BumpAllocator(
+                buffer_size=self.gemm_output_zero_allocator_size,
+                dtype=torch.float32,
+                device=device,
+            )
+            if has_gemm_output_zero_allocator
+            and self.gemm_output_zero_allocator_size > 0
+            else None
+        )
+
         if self.pp_group.is_first_rank:
             if input_embeds is None:
                 hidden_states = self.embed_tokens(input_ids)
@@ -2073,7 +2306,12 @@ class DeepseekV2Model(nn.Module):
             with get_global_expert_distribution_recorder().with_current_layer(i):
                 layer = self.layers[i]
                 hidden_states, residual = layer(
-                    positions, hidden_states, forward_batch, residual, zero_allocator
+                    positions,
+                    hidden_states,
+                    forward_batch,
+                    residual,
+                    zero_allocator,
+                    gemm_output_zero_allocator,
                 )
 
         if normal_end_layer != self.end_layer:
@@ -2177,6 +2415,8 @@ class DeepseekV2ForCausalLM(nn.Module):
             disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
         elif get_moe_expert_parallel_world_size() > 1:
             disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."
+        elif self.quant_config.get_name() == "w4afp8":
+            disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."
 
         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
@@ -2344,6 +2584,16 @@ class DeepseekV2ForCausalLM(nn.Module):
                 w_kc, w_vc = w.unflatten(
                     0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                 ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+
+                if (
+                    _use_aiter_gfx95
+                    and self.quant_config is not None
+                    and self.quant_config.get_name() == "quark"
+                ):
+                    w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = (
+                        quark_post_load_weights(self_attn, w, "mxfp4")
+                    )
+
                 if not use_deep_gemm_bmm:
                     self_attn.w_kc = bind_or_assign(
                         self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
@@ -2406,18 +2656,26 @@ class DeepseekV2ForCausalLM(nn.Module):
         )
 
         num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
+
         for layer_id in range(num_hidden_layers):
             if is_nextn:
                 layer = self.model.decoder
             else:
                 layer = self.model.layers[layer_id]
 
-            for module in [
-                layer.self_attn.fused_qkv_a_proj_with_mqa,
-                layer.self_attn.q_b_proj,
+            module_list = [
                 layer.self_attn.kv_b_proj,
                 layer.self_attn.o_proj,
-            ]:
+            ]
+
+            if self.config.q_lora_rank is not None:
+                module_list.append(layer.self_attn.fused_qkv_a_proj_with_mqa)
+                module_list.append(layer.self_attn.q_b_proj)
+            else:
+                module_list.append(layer.self_attn.kv_a_proj_with_mqa)
+                module_list.append(layer.self_attn.q_proj)
+
+            for module in module_list:
                 requant_weight_ue8m0_inplace(
                     module.weight, module.weight_scale_inv, weight_block_size
                 )
@@ -2480,6 +2738,9 @@ class DeepseekV2ForCausalLM(nn.Module):
             ckpt_up_proj_name="up_proj",
             num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
         )
+        # Params for special naming rules in mixed-precision models, for example:
+        # model.layers.xx.mlp.experts.xx.w1.input_scale. For details,
+        # see https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/blob/main.
         if self.quant_config and self.quant_config.get_name() == "w4afp8":
            expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
                 num_experts=self.config.n_routed_experts
sglang/srt/models/gemma3n_mm.py
@@ -499,7 +499,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
     def should_apply_lora(self, module_name: str) -> bool:
         return bool(self.lora_pattern.match(module_name))
 
-    def get_hidden_dim(self, module_name):
+    def get_hidden_dim(self, module_name, layer_idx):
         # return input_dim, output_dim
         if module_name == "qkv_proj":
             return (