sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (256)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
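Note on the module relocations listed above (items 223 and 250-253): the conversation, code-completion, Jinja-template, and reasoning-parser helpers now live under sglang.srt.parser, and model_parallel.py moved under sglang/srt/layers/. Downstream code importing the old 0.5.1.post2 paths needs an update along the following lines. This is a minimal sketch based only on the renames above; the try/except compatibility fallback is an illustrative pattern, not something the package itself provides.

# Hypothetical import update for the 0.5.2 relocations (module-level imports only):
#   sglang.srt.conversation           -> sglang.srt.parser.conversation
#   sglang.srt.code_completion_parser -> sglang.srt.parser.code_completion_parser
#   sglang.srt.jinja_template_utils   -> sglang.srt.parser.jinja_template_utils
#   sglang.srt.reasoning_parser       -> sglang.srt.parser.reasoning_parser
#   sglang.srt.model_parallel         -> sglang.srt.layers.model_parallel
try:
    from sglang.srt.parser import reasoning_parser  # 0.5.2 and later
except ImportError:
    from sglang.srt import reasoning_parser  # 0.5.1.post2 and earlier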
sglang/srt/models/deepseek_v2.py
@@ -67,7 +67,10 @@ from sglang.srt.layers.moe import (
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
-from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.moe.fused_moe_triton.layer import (
+    FusedMoE,
+    _is_fp4_quantization_enabled,
+)
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -87,8 +90,8 @@ from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
-from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
+from sglang.srt.layers.rotary_embedding import get_rope_wrapper
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -112,8 +115,11 @@ from sglang.srt.utils import (
     is_cpu,
     is_cuda,
     is_flashinfer_available,
+    is_gfx95_supported,
     is_hip,
     is_non_idle_and_non_empty,
+    is_npu,
+    is_sm100_supported,
     log_info_on_rank0,
     make_layers,
     use_intel_amx_backend,
@@ -121,11 +127,28 @@ from sglang.srt.utils import (
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _device_sm = get_device_sm()
+_is_gfx95_supported = is_gfx95_supported()
+
+_use_aiter_gfx95 = _use_aiter and _is_gfx95_supported
+
+if _use_aiter_gfx95:
+    from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights
+    from sglang.srt.layers.quantization.rocm_mxfp4_utils import (
+        batched_gemm_afp4wfp4_pre_quant,
+        fused_flatten_mxfp4_quant,
+        fused_rms_mxfp4_quant,
+    )
+    from sglang.srt.layers.rocm_linear_utils import (
+        aiter_dsv3_router_gemm,
+        fused_qk_rope_cat,
+        get_dsv3_gemm_output_zero_allocator_size,
+    )
 
 if _is_cuda:
     from sgl_kernel import (
@@ -221,10 +244,21 @@ class DeepseekV2MLP(nn.Module):
         forward_batch=None,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x
 
+        if (
+            gemm_output_zero_allocator is not None
+            and x.shape[0] <= 256
+            and self.gate_up_proj.weight.dtype == torch.uint8
+        ):
+            y = gemm_output_zero_allocator.allocate(
+                x.shape[0] * self.gate_up_proj.output_size_per_partition
+            ).view(x.shape[0], self.gate_up_proj.output_size_per_partition)
+            x = (x, None, y)
+
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(
@@ -254,7 +288,7 @@ class MoEGate(nn.Module):
         if _is_cpu and _is_cpu_amx_available:
             self.quant_method = PackWeightMethod(weight_names=["weight"])
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None):
         if use_intel_amx_backend(self):
             return torch.ops.sgl_kernel.weight_packed_linear(
                 hidden_states,
@@ -272,7 +306,13 @@ class MoEGate(nn.Module):
             and _device_sm >= 90
         ):
             # router gemm output float32
-            logits = dsv3_router_gemm(hidden_states, self.weight)
+            logits = dsv3_router_gemm(
+                hidden_states, self.weight, out_dtype=torch.float32
+            )
+        elif _use_aiter_gfx95 and hidden_states.shape[0] <= 256:
+            logits = aiter_dsv3_router_gemm(
+                hidden_states, self.weight, gemm_output_zero_allocator
+            )
         else:
             logits = F.linear(hidden_states, self.weight, None)
 
@@ -333,6 +373,9 @@ class DeepseekV2MoE(nn.Module):
             prefix=add_prefix("experts", prefix),
         )
 
+        correction_bias = self.gate.e_score_correction_bias
+        if _is_fp4_quantization_enabled():
+            correction_bias = correction_bias.to(torch.bfloat16)
         self.topk = TopK(
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             renormalize=config.norm_topk_prob,
@@ -340,7 +383,7 @@ class DeepseekV2MoE(nn.Module):
             num_expert_group=config.n_group,
             num_fused_shared_experts=self.num_fused_shared_experts,
             topk_group=config.topk_group,
-            correction_bias=self.gate.e_score_correction_bias,
+            correction_bias=correction_bias,
             routed_scaling_factor=self.routed_scaling_factor,
             apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
             force_topk=quant_config is None,
@@ -436,6 +479,7 @@ class DeepseekV2MoE(nn.Module):
         forward_batch: Optional[ForwardBatch] = None,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024
@@ -449,12 +493,14 @@ class DeepseekV2MoE(nn.Module):
                     hidden_states,
                     should_allreduce_fusion,
                     use_reduce_scatter,
+                    gemm_output_zero_allocator,
                 )
             else:
                 return self.forward_normal(
                     hidden_states,
                     should_allreduce_fusion,
                     use_reduce_scatter,
+                    gemm_output_zero_allocator,
                 )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
@@ -464,15 +510,18 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
 
         current_stream = torch.cuda.current_stream()
         self.alt_stream.wait_stream(current_stream)
-        shared_output = self._forward_shared_experts(hidden_states)
+        shared_output = self._forward_shared_experts(
+            hidden_states, gemm_output_zero_allocator
+        )
 
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(hidden_states)
+            router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
             final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:
@@ -499,6 +548,7 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
@@ -506,9 +556,11 @@ class DeepseekV2MoE(nn.Module):
             return self.forward_cpu(hidden_states, should_allreduce_fusion)
 
         if hidden_states.shape[0] > 0:
-            shared_output = self._forward_shared_experts(hidden_states)
+            shared_output = self._forward_shared_experts(
+                hidden_states, gemm_output_zero_allocator
+            )
             # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(hidden_states)
+            router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
         else:
             shared_output = None
@@ -628,9 +680,13 @@ class DeepseekV2MoE(nn.Module):
 
         return final_hidden_states
 
-    def _forward_shared_experts(self, hidden_states):
+    def _forward_shared_experts(
+        self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None
+    ):
         if self.num_fused_shared_experts == 0:
-            return self.shared_experts(hidden_states)
+            return self.shared_experts(
+                hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
+            )
         else:
             return None
 
@@ -989,17 +1045,32 @@ class DeepseekV2AttentionMLA(nn.Module):
         # Determine attention backend used by current forward batch
         if forward_batch.forward_mode.is_decode_or_idle():
             attention_backend = global_server_args_dict["decode_attention_backend"]
+        elif (
+            forward_batch.forward_mode.is_target_verify()
+            or forward_batch.forward_mode.is_draft_extend()
+        ):
+            # Use the specified backend for speculative operations (both verify and draft extend)
+            if global_server_args_dict["speculative_attention_mode"] == "decode":
+                attention_backend = global_server_args_dict["decode_attention_backend"]
+            else:  # default to prefill
+                attention_backend = global_server_args_dict["prefill_attention_backend"]
         else:
             attention_backend = global_server_args_dict["prefill_attention_backend"]
         self.current_attention_backend = attention_backend
 
         if attention_backend == "ascend":
-            return AttnForwardMethod.MLA
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                return AttnForwardMethod.MHA
+            else:
+                return AttnForwardMethod.MLA
         elif (
             attention_backend == "flashinfer"
             or attention_backend == "fa3"
             or attention_backend == "flashmla"
-            or attention_backend == "trtllm_mla"
             or attention_backend == "cutlass_mla"
         ):
             # Use MHA with chunked KV cache when prefilling on long sequences.
@@ -1028,13 +1099,28 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
                 return _dispatch_mla_subtype()
+        elif attention_backend == "trtllm_mla":
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                return AttnForwardMethod.MHA_CHUNKED_KV
+            else:
+                return _dispatch_mla_subtype()
         elif attention_backend == "aiter":
             if (
                 forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
             ):
-                return AttnForwardMethod.MHA
+                if is_dp_attention_enabled():
+                    if sum(forward_batch.extend_prefix_lens_cpu) == 0:
+                        return AttnForwardMethod.MHA
+                    else:
+                        return AttnForwardMethod.MLA
+                else:
+                    return AttnForwardMethod.MHA
             else:
                 return AttnForwardMethod.MLA
         else:
@@ -1087,11 +1173,19 @@ class DeepseekV2AttentionMLA(nn.Module):
         if self.attn_mha.kv_b_proj is None:
             self.attn_mha.kv_b_proj = self.kv_b_proj
 
-        if hidden_states.shape[0] == 0:
-            assert (
-                not self.o_proj.reduce_results
-            ), "short-circuiting allreduce will lead to hangs"
-            return hidden_states, None, forward_batch, None
+        # when hidden_states is a tuple of tensors, the tuple will include quantized weight and scale tensor
+        if isinstance(hidden_states, tuple):
+            if hidden_states[0].shape[0] == 0:
+                assert (
+                    not self.o_proj.reduce_results
+                ), "short-circuiting allreduce will lead to hangs"
+                return hidden_states[0]
+        else:
+            if hidden_states.shape[0] == 0:
+                assert (
+                    not self.o_proj.reduce_results
+                ), "short-circuiting allreduce will lead to hangs"
+                return hidden_states, None, forward_batch, None
 
         attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
 
@@ -1173,13 +1267,19 @@ class DeepseekV2AttentionMLA(nn.Module):
         k[..., : self.qk_nope_head_dim] = k_nope
         k[..., self.qk_nope_head_dim :] = k_pe
 
-        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
-        latent_cache[:, :, self.kv_lora_rank :] = k_pe
+        if not _is_npu:
+            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+            latent_cache[:, :, self.kv_lora_rank :] = k_pe
 
-        # Save latent cache
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
-        )
+            # Save latent cache
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+            )
+        else:
+            # To reduce a time-costing split operation
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            )
 
         return q, k, v, forward_batch
 
@@ -1209,7 +1309,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 
         if self.q_lora_rank is not None:
-            if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
+            if (
+                (not isinstance(hidden_states, tuple))
+                and hidden_states.shape[0] <= 16
+                and self.use_min_latency_fused_a_gemm
+            ):
                 fused_qkv_a_proj_out = dsv3_fused_a_gemm(
                     hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
                 )
@@ -1229,8 +1333,18 @@ class DeepseekV2AttentionMLA(nn.Module):
                 k_nope = self.kv_a_layernorm(k_nope)
                 current_stream.wait_stream(self.alt_stream)
             else:
-                q = self.q_a_layernorm(q)
-                k_nope = self.kv_a_layernorm(k_nope)
+                if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
+                    q, k_nope = fused_rms_mxfp4_quant(
+                        q,
+                        self.q_a_layernorm.weight,
+                        self.q_a_layernorm.variance_epsilon,
+                        k_nope,
+                        self.kv_a_layernorm.weight,
+                        self.kv_a_layernorm.variance_epsilon,
+                    )
+                else:
+                    q = self.q_a_layernorm(q)
+                    k_nope = self.kv_a_layernorm(k_nope)
 
             k_nope = k_nope.unsqueeze(1)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
@@ -1262,10 +1376,27 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = q_nope_out[:, :expected_m, :]
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
-            q_nope_out = torch.bmm(
-                q_nope.to(torch.bfloat16).transpose(0, 1),
-                self.w_kc.to(torch.bfloat16) * self.w_scale,
-            )
+            if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
+                x = q_nope.transpose(0, 1)
+                q_nope_out = torch.empty(
+                    x.shape[0],
+                    x.shape[1],
+                    self.w_kc.shape[2],
+                    device=x.device,
+                    dtype=torch.bfloat16,
+                )
+                batched_gemm_afp4wfp4_pre_quant(
+                    x,
+                    self.w_kc.transpose(-2, -1),
+                    self.w_scale_k.transpose(-2, -1),
+                    torch.bfloat16,
+                    q_nope_out,
+                )
+            else:
+                q_nope_out = torch.bmm(
+                    q_nope.to(torch.bfloat16).transpose(0, 1),
+                    self.w_kc.to(torch.bfloat16) * self.w_scale,
+                )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                 q_nope.transpose(0, 1),
@@ -1279,19 +1410,22 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         q_nope_out = q_nope_out.transpose(0, 1)
 
-        if not self._fuse_rope_for_trtllm_mla(forward_batch):
+        if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
+            not _use_aiter or not _is_gfx95_supported
+        ):
             q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
-        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
 
     def forward_absorb_core(
-        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
     ):
         if (
             self.current_attention_backend == "fa3"
             or self.current_attention_backend == "flashinfer"
             or self.current_attention_backend == "cutlass_mla"
             or self.current_attention_backend == "trtllm_mla"
+            or self.current_attention_backend == "ascend"
         ):
             extra_args = {}
             if self._fuse_rope_for_trtllm_mla(forward_batch):
@@ -1309,8 +1443,23 @@ class DeepseekV2AttentionMLA(nn.Module):
                 **extra_args,
             )
         else:
-            q = torch.cat([q_nope_out, q_pe], dim=-1)
-            k = torch.cat([k_nope, k_pe], dim=-1)
+            if _use_aiter_gfx95:
+                cos = self.rotary_emb.cos_cache
+                sin = self.rotary_emb.sin_cache
+                q, k = fused_qk_rope_cat(
+                    q_nope_out,
+                    q_pe,
+                    k_nope,
+                    k_pe,
+                    positions,
+                    cos,
+                    sin,
+                    self.rotary_emb.is_neox_style,
+                )
+            else:
+                q = torch.cat([q_nope_out, q_pe], dim=-1)
+                k = torch.cat([k_nope, k_pe], dim=-1)
+
             attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
 
@@ -1335,11 +1484,34 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
-            attn_bmm_output = torch.bmm(
-                attn_output.to(torch.bfloat16).transpose(0, 1),
-                self.w_vc.to(torch.bfloat16) * self.w_scale,
-            )
-            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
+            if _use_aiter_gfx95 and self.w_vc.dtype == torch.uint8:
+                x = attn_output.transpose(0, 1)
+                attn_bmm_output = torch.empty(
+                    x.shape[0],
+                    x.shape[1],
+                    self.w_vc.shape[2],
+                    device=x.device,
+                    dtype=torch.bfloat16,
+                )
+                batched_gemm_afp4wfp4_pre_quant(
+                    x,
+                    self.w_vc.transpose(-2, -1),
+                    self.w_scale_v.transpose(-2, -1),
+                    torch.bfloat16,
+                    attn_bmm_output,
+                )
+            else:
+                attn_bmm_output = torch.bmm(
+                    attn_output.to(torch.bfloat16).transpose(0, 1),
+                    self.w_vc.to(torch.bfloat16) * self.w_scale,
+                )
+
+            if self.o_proj.weight.dtype == torch.uint8:
+                attn_bmm_output = attn_bmm_output.transpose(0, 1)
+                attn_bmm_output = fused_flatten_mxfp4_quant(attn_bmm_output)
+            else:
+                attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
+
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),
@@ -1661,9 +1833,11 @@ class DeepseekV2AttentionMLA(nn.Module):
             latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
                 self.attn_mha.layer_id
             )
-            latent_cache = latent_cache_buf[
-                forward_batch.prefix_chunk_kv_indices[i]
-            ].contiguous()
+            latent_cache = (
+                latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
+                .contiguous()
+                .to(q.dtype)
+            )
 
             kv_a_normed, k_pe = latent_cache.split(
                 [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
@@ -1847,10 +2021,24 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
 
+        quant_format = (
+            "mxfp4"
+            if _is_gfx95_supported
+            and getattr(self.self_attn, "fused_qkv_a_proj_with_mqa", None) is not None
+            and getattr(self.self_attn.fused_qkv_a_proj_with_mqa, "weight", None)
+            is not None
+            and self.self_attn.fused_qkv_a_proj_with_mqa.weight.dtype == torch.uint8
+            else ""
+        )
+
         hidden_states, residual = self.layer_communicator.prepare_attn(
-            hidden_states, residual, forward_batch
+            hidden_states,
+            residual,
+            forward_batch,
+            quant_format,
         )
 
         hidden_states = self.self_attn(
@@ -1874,8 +2062,16 @@ class DeepseekV2DecoderLayer(nn.Module):
         use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
             forward_batch
         )
+
+        if isinstance(self.mlp, DeepseekV2MLP):
+            gemm_output_zero_allocator = None
+
         hidden_states = self.mlp(
-            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
+            hidden_states,
+            forward_batch,
+            should_allreduce_fusion,
+            use_reduce_scatter,
+            gemm_output_zero_allocator,
         )
 
         if should_allreduce_fusion:
@@ -2019,6 +2215,37 @@ class DeepseekV2Model(nn.Module):
         else:
             self.norm = PPMissingLayer(return_tuple=True)
 
+        self.gemm_output_zero_allocator_size = 0
+        if (
+            _use_aiter_gfx95
+            and config.n_routed_experts == 256
+            and self.embed_tokens.embedding_dim == 7168
+        ):
+            num_moe_layers = sum(
+                [
+                    1
+                    for i in range(len(self.layers))
+                    if isinstance(self.layers[i].mlp, DeepseekV2MoE)
+                ]
+            )
+
+            allocate_size = 0
+            for i in range(len(self.layers)):
+                if isinstance(self.layers[i].mlp, DeepseekV2MoE):
+                    allocate_size = self.layers[
+                        i
+                    ].mlp.shared_experts.gate_up_proj.output_size_per_partition
+                    break
+
+            self.gemm_output_zero_allocator_size = (
+                get_dsv3_gemm_output_zero_allocator_size(
+                    config.n_routed_experts,
+                    num_moe_layers,
+                    allocate_size,
+                    self.embed_tokens.embedding_dim,
+                )
+            )
+
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
 
@@ -2038,6 +2265,21 @@ class DeepseekV2Model(nn.Module):
             device=device,
         )
 
+        has_gemm_output_zero_allocator = hasattr(
+            self, "gemm_output_zero_allocator_size"
+        )
+
+        gemm_output_zero_allocator = (
+            BumpAllocator(
+                buffer_size=self.gemm_output_zero_allocator_size,
+                dtype=torch.float32,
+                device=device,
+            )
+            if has_gemm_output_zero_allocator
+            and self.gemm_output_zero_allocator_size > 0
+            else None
+        )
+
         if self.pp_group.is_first_rank:
             if input_embeds is None:
                 hidden_states = self.embed_tokens(input_ids)
@@ -2064,7 +2306,12 @@ class DeepseekV2Model(nn.Module):
             with get_global_expert_distribution_recorder().with_current_layer(i):
                 layer = self.layers[i]
                 hidden_states, residual = layer(
-                    positions, hidden_states, forward_batch, residual, zero_allocator
+                    positions,
+                    hidden_states,
+                    forward_batch,
+                    residual,
+                    zero_allocator,
+                    gemm_output_zero_allocator,
                 )
 
         if normal_end_layer != self.end_layer:
@@ -2168,6 +2415,8 @@ class DeepseekV2ForCausalLM(nn.Module):
             disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
         elif get_moe_expert_parallel_world_size() > 1:
             disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."
+        elif self.quant_config.get_name() == "w4afp8":
+            disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."
 
         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
@@ -2335,6 +2584,16 @@ class DeepseekV2ForCausalLM(nn.Module):
             w_kc, w_vc = w.unflatten(
                 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
             ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+
+            if (
+                _use_aiter_gfx95
+                and self.quant_config is not None
+                and self.quant_config.get_name() == "quark"
+            ):
+                w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = (
+                    quark_post_load_weights(self_attn, w, "mxfp4")
+                )
+
             if not use_deep_gemm_bmm:
                 self_attn.w_kc = bind_or_assign(
                     self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
@@ -2397,18 +2656,26 @@ class DeepseekV2ForCausalLM(nn.Module):
         )
 
         num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
+
         for layer_id in range(num_hidden_layers):
             if is_nextn:
                 layer = self.model.decoder
             else:
                 layer = self.model.layers[layer_id]
 
-            for module in [
-                layer.self_attn.fused_qkv_a_proj_with_mqa,
-                layer.self_attn.q_b_proj,
+            module_list = [
                 layer.self_attn.kv_b_proj,
                 layer.self_attn.o_proj,
-            ]:
+            ]
+
+            if self.config.q_lora_rank is not None:
+                module_list.append(layer.self_attn.fused_qkv_a_proj_with_mqa)
+                module_list.append(layer.self_attn.q_b_proj)
+            else:
+                module_list.append(layer.self_attn.kv_a_proj_with_mqa)
+                module_list.append(layer.self_attn.q_proj)
+
+            for module in module_list:
                 requant_weight_ue8m0_inplace(
                     module.weight, module.weight_scale_inv, weight_block_size
                 )
@@ -2471,6 +2738,9 @@ class DeepseekV2ForCausalLM(nn.Module):
             ckpt_up_proj_name="up_proj",
             num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
         )
+        # Params for special naming rules in mixed-precision models, for example:
+        # model.layers.xx.mlp.experts.xx.w1.input_scale. For details,
+        # see https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/blob/main.
         if self.quant_config and self.quant_config.get_name() == "w4afp8":
             expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
                 num_experts=self.config.n_routed_experts
sglang/srt/models/gemma3n_mm.py
@@ -499,7 +499,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
     def should_apply_lora(self, module_name: str) -> bool:
         return bool(self.lora_pattern.match(module_name))
 
-    def get_hidden_dim(self, module_name):
+    def get_hidden_dim(self, module_name, layer_idx):
         # return input_dim, output_dim
         if module_name == "qkv_proj":
             return (