sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (219)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +208 -295
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  64. sglang/srt/layers/moe/topk.py +13 -4
  65. sglang/srt/layers/quantization/__init__.py +111 -7
  66. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  69. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  71. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  72. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  73. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  75. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  80. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  82. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  83. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  86. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  87. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  89. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  90. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  91. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  93. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  94. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  95. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  96. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  97. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  98. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  99. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  100. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  101. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  102. sglang/srt/layers/quantization/fp8.py +69 -28
  103. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  104. sglang/srt/layers/quantization/gptq.py +416 -0
  105. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  106. sglang/srt/layers/quantization/int8_utils.py +73 -0
  107. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  108. sglang/srt/layers/radix_attention.py +1 -0
  109. sglang/srt/layers/rotary_embedding.py +0 -1
  110. sglang/srt/layers/sampler.py +76 -31
  111. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  112. sglang/srt/lora/lora.py +17 -1
  113. sglang/srt/lora/lora_config.py +5 -0
  114. sglang/srt/lora/lora_manager.py +1 -3
  115. sglang/srt/managers/cache_controller.py +193 -62
  116. sglang/srt/managers/configure_logging.py +2 -1
  117. sglang/srt/managers/data_parallel_controller.py +6 -2
  118. sglang/srt/managers/detokenizer_manager.py +124 -102
  119. sglang/srt/managers/image_processor.py +2 -1
  120. sglang/srt/managers/io_struct.py +143 -6
  121. sglang/srt/managers/schedule_batch.py +238 -197
  122. sglang/srt/managers/schedule_policy.py +29 -29
  123. sglang/srt/managers/scheduler.py +681 -259
  124. sglang/srt/managers/session_controller.py +6 -2
  125. sglang/srt/managers/tokenizer_manager.py +224 -68
  126. sglang/srt/managers/tp_worker.py +15 -4
  127. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  128. sglang/srt/mem_cache/chunk_cache.py +18 -11
  129. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  130. sglang/srt/mem_cache/memory_pool.py +44 -18
  131. sglang/srt/mem_cache/radix_cache.py +58 -47
  132. sglang/srt/metrics/collector.py +94 -36
  133. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  134. sglang/srt/model_executor/forward_batch_info.py +49 -16
  135. sglang/srt/model_executor/model_runner.py +209 -28
  136. sglang/srt/model_loader/loader.py +3 -3
  137. sglang/srt/model_loader/weight_utils.py +36 -14
  138. sglang/srt/models/baichuan.py +31 -6
  139. sglang/srt/models/chatglm.py +39 -7
  140. sglang/srt/models/commandr.py +29 -5
  141. sglang/srt/models/dbrx.py +31 -5
  142. sglang/srt/models/deepseek.py +43 -6
  143. sglang/srt/models/deepseek_nextn.py +32 -19
  144. sglang/srt/models/deepseek_v2.py +265 -29
  145. sglang/srt/models/exaone.py +19 -9
  146. sglang/srt/models/gemma.py +22 -8
  147. sglang/srt/models/gemma2.py +25 -12
  148. sglang/srt/models/gemma2_reward.py +5 -1
  149. sglang/srt/models/gpt2.py +28 -13
  150. sglang/srt/models/gpt_bigcode.py +27 -5
  151. sglang/srt/models/granite.py +21 -9
  152. sglang/srt/models/grok.py +21 -4
  153. sglang/srt/models/internlm2.py +36 -6
  154. sglang/srt/models/internlm2_reward.py +5 -1
  155. sglang/srt/models/llama.py +26 -9
  156. sglang/srt/models/llama_classification.py +5 -1
  157. sglang/srt/models/llama_eagle.py +17 -4
  158. sglang/srt/models/llama_embedding.py +5 -1
  159. sglang/srt/models/llama_reward.py +7 -2
  160. sglang/srt/models/llava.py +19 -3
  161. sglang/srt/models/llavavid.py +10 -1
  162. sglang/srt/models/minicpm.py +26 -2
  163. sglang/srt/models/minicpm3.py +39 -3
  164. sglang/srt/models/minicpmv.py +45 -14
  165. sglang/srt/models/mixtral.py +20 -9
  166. sglang/srt/models/mixtral_quant.py +50 -8
  167. sglang/srt/models/mllama.py +57 -11
  168. sglang/srt/models/olmo.py +34 -6
  169. sglang/srt/models/olmo2.py +34 -13
  170. sglang/srt/models/olmoe.py +26 -4
  171. sglang/srt/models/phi3_small.py +29 -10
  172. sglang/srt/models/qwen.py +26 -3
  173. sglang/srt/models/qwen2.py +26 -4
  174. sglang/srt/models/qwen2_5_vl.py +46 -8
  175. sglang/srt/models/qwen2_eagle.py +17 -5
  176. sglang/srt/models/qwen2_moe.py +44 -6
  177. sglang/srt/models/qwen2_rm.py +78 -0
  178. sglang/srt/models/qwen2_vl.py +39 -8
  179. sglang/srt/models/stablelm.py +32 -5
  180. sglang/srt/models/torch_native_llama.py +5 -2
  181. sglang/srt/models/xverse.py +21 -9
  182. sglang/srt/models/xverse_moe.py +45 -7
  183. sglang/srt/models/yivl.py +2 -1
  184. sglang/srt/openai_api/adapter.py +109 -24
  185. sglang/srt/openai_api/protocol.py +17 -1
  186. sglang/srt/reasoning_parser.py +154 -0
  187. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  188. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  189. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  190. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  191. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  192. sglang/srt/sampling/sampling_batch_info.py +79 -157
  193. sglang/srt/sampling/sampling_params.py +16 -13
  194. sglang/srt/server_args.py +136 -52
  195. sglang/srt/speculative/build_eagle_tree.py +2 -8
  196. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  197. sglang/srt/speculative/eagle_utils.py +92 -58
  198. sglang/srt/speculative/eagle_worker.py +186 -94
  199. sglang/srt/speculative/spec_info.py +1 -13
  200. sglang/srt/utils.py +43 -17
  201. sglang/srt/warmup.py +47 -0
  202. sglang/test/few_shot_gsm8k.py +4 -1
  203. sglang/test/runners.py +389 -126
  204. sglang/test/send_one.py +88 -0
  205. sglang/test/test_block_fp8_ep.py +361 -0
  206. sglang/test/test_programs.py +1 -1
  207. sglang/test/test_utils.py +138 -84
  208. sglang/utils.py +50 -60
  209. sglang/version.py +1 -1
  210. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  211. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
  212. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  213. sglang/bench_latency.py +0 -1
  214. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  215. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  216. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  217. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  218. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  219. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py

@@ -16,6 +16,7 @@
 # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
 """Inference-only DeepseekV2 model."""
 
+import os
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -31,6 +32,9 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
+    decode_attention_fwd_grouped_rope,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -47,6 +51,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
+from sglang.srt.layers.quantization.int8_utils import (
+    block_dequant as int8_block_dequant,
+)
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
 from sglang.srt.layers.vocab_parallel_embedding import (
@@ -56,7 +63,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import is_cuda_available, is_hip
+from sglang.srt.utils import add_prefix, is_cuda_available, is_hip
 
 is_hip_ = is_hip()
 
@@ -72,10 +79,15 @@ class DeepseekV2MLP(nn.Module):
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -83,6 +95,7 @@ class DeepseekV2MLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             reduce_results=reduce_results,
+            prefix=add_prefix("down_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -99,7 +112,11 @@
 
 
 class MoEGate(nn.Module):
-    def __init__(self, config):
+    def __init__(
+        self,
+        config,
+        prefix: str = "",
+    ):
         super().__init__()
         self.weight = nn.Parameter(
             torch.empty((config.n_routed_experts, config.hidden_size))
@@ -122,6 +139,7 @@ class DeepseekV2MoE(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -140,7 +158,7 @@
                 "Only silu is supported for now."
             )
 
-        self.gate = MoEGate(config=config)
+        self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
 
         MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
         self.experts = MoEImpl(
@@ -154,6 +172,7 @@ class DeepseekV2MoE(nn.Module):
             num_expert_group=config.n_group,
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
+            prefix=add_prefix("experts", prefix),
         )
 
         if config.n_shared_experts is not None:
@@ -164,6 +183,7 @@ class DeepseekV2MoE(nn.Module):
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 reduce_results=False,
+                prefix=add_prefix("shared_experts", prefix),
             )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
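The add_prefix(...) calls threaded through the constructors above, and throughout the rest of this diff, build dotted module paths (for example model.layers.3.mlp.experts) so that quantization configs and weight loaders can match parameters by their fully qualified names. A minimal sketch of the helper's assumed behavior; the real definition lives in sglang/srt/utils.py and may differ in detail:

# Hedged reconstruction of add_prefix; treat the body as an assumption.
def add_prefix(name: str, prefix: str) -> str:
    # Join a child module name onto its parent's dotted path; an empty prefix
    # (the default in every constructor above) leaves the name unchanged.
    return name if not prefix else f"{prefix}.{name}"

# How prefixes compose through DeepseekV2Model -> DecoderLayer -> MoE -> experts:
assert (
    add_prefix("experts", add_prefix("mlp", add_prefix("layers.3", "model")))
    == "model.layers.3.mlp.experts"
)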
@@ -210,6 +230,7 @@ class DeepseekV2Attention(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -234,6 +255,7 @@ class DeepseekV2Attention(nn.Module):
                 self.q_lora_rank,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_a_proj", prefix),
             )
             self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
             self.q_b_proj = ColumnParallelLinear(
@@ -241,6 +263,7 @@ class DeepseekV2Attention(nn.Module):
                 self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_b_proj", prefix),
             )
         else:
             self.q_proj = ColumnParallelLinear(
@@ -248,6 +271,7 @@ class DeepseekV2Attention(nn.Module):
                 self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("q_proj", prefix),
             )
 
         self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -255,8 +279,7 @@ class DeepseekV2Attention(nn.Module):
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
-            # FIXME: quick fix for skip quantization
-            prefix=f"self_attn.kv_a_proj_with_mqa",
+            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
         self.kv_b_proj = ColumnParallelLinear(
@@ -264,6 +287,7 @@ class DeepseekV2Attention(nn.Module):
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("kv_b_proj", prefix),
         )
         # O projection.
         self.o_proj = RowParallelLinear(
@@ -271,6 +295,7 @@ class DeepseekV2Attention(nn.Module):
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
         rope_scaling["rope_type"] = "deepseek_yarn"
         self.rotary_emb = get_rope_wrapper(
@@ -296,6 +321,7 @@ class DeepseekV2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_local_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
 
     def forward(
@@ -361,6 +387,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
         use_dp=False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -387,6 +414,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                     self.q_lora_rank,
                     bias=False,
                     quant_config=quant_config,
+                    prefix=add_prefix("q_a_proj", prefix),
                 )
                 self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
                 self.q_b_proj = ReplicatedLinear(
@@ -394,6 +422,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                     self.num_heads * self.qk_head_dim,
                     bias=False,
                     quant_config=quant_config,
+                    prefix=add_prefix("q_b_proj", prefix),
                 )
             else:
                 self.q_proj = ReplicatedLinear(
@@ -401,12 +430,14 @@ class DeepseekV2AttentionMLA(nn.Module):
                     self.num_heads * self.qk_head_dim,
                     bias=False,
                     quant_config=quant_config,
+                    prefix=add_prefix("q_proj", prefix),
                 )
             self.kv_b_proj = ReplicatedLinear(
                 self.kv_lora_rank,
                 self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("kv_b_proj", prefix),
             )
             # O projection.
             self.o_proj = ReplicatedLinear(
@@ -414,6 +445,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.hidden_size,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("o_proj", prefix),
             )
         else:
             # For tensor parallel attention
@@ -423,6 +455,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                     self.q_lora_rank,
                     bias=False,
                     quant_config=quant_config,
+                    prefix=add_prefix("q_a_proj", prefix),
                 )
                 self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
                 self.q_b_proj = ColumnParallelLinear(
@@ -430,6 +463,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                     self.num_heads * self.qk_head_dim,
                     bias=False,
                     quant_config=quant_config,
+                    prefix=add_prefix("q_b_proj", prefix),
                 )
             else:
                 self.q_proj = ColumnParallelLinear(
@@ -437,12 +471,14 @@ class DeepseekV2AttentionMLA(nn.Module):
                     self.num_heads * self.qk_head_dim,
                     bias=False,
                     quant_config=quant_config,
+                    prefix=add_prefix("q_proj", prefix),
                 )
             self.kv_b_proj = ColumnParallelLinear(
                 self.kv_lora_rank,
                 self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("kv_b_proj", prefix),
             )
             # O projection.
             self.o_proj = RowParallelLinear(
@@ -450,6 +486,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.hidden_size,
                 bias=False,
                 quant_config=quant_config,
+                prefix=add_prefix("o_proj", prefix),
             )
 
         self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -457,8 +494,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
-            # FIXME: quick fix for skip quantization
-            prefix=f"self_attn.kv_a_proj_with_mqa",
+            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
 
@@ -489,6 +525,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             num_kv_heads=1,
             layer_id=layer_id,
             v_head_dim=self.kv_lora_rank,
+            prefix=add_prefix("attn_mqa", prefix),
         )
 
         self.attn_mha = RadixAttention(
@@ -498,6 +535,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             num_kv_heads=self.num_local_heads,
             layer_id=layer_id,
             v_head_dim=self.v_head_dim,
+            prefix=add_prefix("attn_mha", prefix),
        )
 
         self.w_kc = None
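The w_kc and w_vc buffers initialized to None here are filled from the kv_b_proj weight in load_weights (see the change near the end of this diff) and consumed by the weight-absorbed forward paths. A small, self-contained shape check of the per-head bmm those paths perform; the dimensions are inferred from the code in this diff and the sizes are illustrative:

import torch

q_len, num_heads = 4, 16
qk_nope_head_dim, kv_lora_rank = 128, 512  # illustrative DeepSeek-V2-style sizes

q_nope = torch.randn(q_len, num_heads, qk_nope_head_dim)
# w_kc: one (qk_nope_head_dim, kv_lora_rank) slice of kv_b_proj per head, i.e. the
# first half of w.unflatten(0, (-1, qk_nope + v)).split(...) in load_weights below.
w_kc = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)

# Same contraction as torch.bmm(q_nope.transpose(0, 1), self.w_kc) in the absorbed
# path: project the no-RoPE part of the query directly into the compressed KV latent space.
q_latent = torch.bmm(q_nope.transpose(0, 1), w_kc).transpose(0, 1)
assert q_latent.shape == (q_len, num_heads, kv_lora_rank)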
@@ -510,20 +548,37 @@ class DeepseekV2AttentionMLA(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        if global_server_args_dict["enable_flashinfer_mla"]:
-            if forward_batch.forward_mode.is_extend():
-                return self.forward_normal(positions, hidden_states, forward_batch)
+
+        def no_absorb() -> bool:
+            if global_server_args_dict["enable_flashinfer_mla"]:
+                # Flashinfer MLA: Do not absorb when enabling ragged prefill
+                return (
+                    not global_server_args_dict["flashinfer_mla_disable_ragged"]
+                    and forward_batch.forward_mode.is_extend()
+                    and forward_batch.extend_prefix_lens.sum() == 0
+                )
             else:
-                return self.forward_absorb(positions, hidden_states, forward_batch)
+                # Triton: Use normal computation for prefill and use weight absorption for extend/decode
+                return (
+                    forward_batch.forward_mode.is_extend()
+                    and not forward_batch.forward_mode.is_target_verify()
+                    and not forward_batch.forward_mode.is_draft_extend()
+                    and forward_batch.extend_prefix_lens.sum() == 0
+                )
+
+        if no_absorb():
+            return self.forward_normal(positions, hidden_states, forward_batch)
         else:
-            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
-            if (
-                forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-                and forward_batch.extend_prefix_lens.sum() == 0
-            ):
-                return self.forward_normal(positions, hidden_states, forward_batch)
+            if is_hip_:
+                if (
+                    os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
+                    and forward_batch.forward_mode.is_decode()
+                ):
+                    return self.forward_absorb_fused_mla_rope(
+                        positions, hidden_states, forward_batch
+                    )
+                else:
+                    return self.forward_absorb(positions, hidden_states, forward_batch)
             else:
                 return self.forward_absorb(positions, hidden_states, forward_batch)
 
@@ -644,6 +699,149 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         return output
 
+    def forward_absorb_fused_mla_rope(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        enable_rope_fusion = (
+            os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
+        )
+        q_len = hidden_states.shape[0]
+        q_input = hidden_states.new_empty(
+            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
+        )
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+        if self.w_kc.dtype == torch.float8_e4m3fnuz:
+            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+            q_nope_out = torch.bmm(
+                q_nope.to(torch.bfloat16).transpose(0, 1),
+                self.w_kc.to(torch.bfloat16) * self.w_scale,
+            )
+        elif self.w_kc.dtype == torch.float8_e4m3fn:
+            q_nope_val, q_nope_scale = input_to_float8(
+                q_nope.transpose(0, 1), torch.float8_e4m3fn
+            )
+            q_nope_out = bmm_fp8(
+                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
+            )
+        else:
+            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
+        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
+
+        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        v_input = latent_cache[..., : self.kv_lora_rank]
+        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
+        k_input = latent_cache.unsqueeze(1)
+        k_input[..., : self.kv_lora_rank] = v_input
+
+        if not enable_rope_fusion:
+            k_pe = k_input[..., self.kv_lora_rank :]
+            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+            q_input[..., self.kv_lora_rank :] = q_pe
+            k_input[..., self.kv_lora_rank :] = k_pe
+            k_pe_output = None
+        else:
+            k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :])
+
+        q_input[..., self.kv_lora_rank :] = q_pe
+
+        # attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
+        # Use Fused ROPE with use_rope=OFF.
+        attn_output = torch.empty(
+            (q_len, self.num_local_heads, self.kv_lora_rank),
+            dtype=q.dtype,
+            device=q.device,
+        )
+        attn_logits, _, kv_indptr, kv_indices, _, _, _ = (
+            forward_batch.attn_backend.forward_metadata
+        )
+        cos_sin_cache = self.rotary_emb.cos_sin_cache
+        num_kv_split = forward_batch.attn_backend.num_kv_splits
+        sm_scale = self.attn_mqa.scaling
+        if attn_logits is None:
+            attn_logits = torch.empty(
+                (
+                    forward_batch.batch_size,
+                    self.num_local_heads,
+                    num_kv_split,
+                    self.kv_lora_rank + 1,
+                ),
+                dtype=torch.float32,
+                device=q.device,
+            )
+
+        # save current latent cache.
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            self.attn_mqa, forward_batch.out_cache_loc, k_input, None
+        )
+        key_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
+            self.attn_mqa.layer_id
+        )
+        val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]
+
+        decode_attention_fwd_grouped_rope(
+            q_input,
+            key_cache_buf,
+            val_cache_buf,
+            attn_output,
+            kv_indptr,
+            kv_indices,
+            k_pe_output,
+            self.kv_lora_rank,
+            self.rotary_emb.rotary_dim,
+            cos_sin_cache,
+            positions,
+            attn_logits,
+            num_kv_split,
+            sm_scale,
+            logit_cap=self.attn_mqa.logit_cap,
+            use_rope=enable_rope_fusion,
+            is_neox_style=self.rotary_emb.is_neox_style,
+        )
+
+        if enable_rope_fusion:
+            k_input[..., self.kv_lora_rank :] = k_pe_output
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mqa, forward_batch.out_cache_loc, k_input, None
+            )
+
+        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+
+        if self.w_vc.dtype == torch.float8_e4m3fnuz:
+            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+            attn_bmm_output = torch.bmm(
+                attn_output.to(torch.bfloat16).transpose(0, 1),
+                self.w_vc.to(torch.bfloat16) * self.w_scale,
+            )
+        elif self.w_vc.dtype == torch.float8_e4m3fn:
+            attn_output_val, attn_output_scale = input_to_float8(
+                attn_output.transpose(0, 1), torch.float8_e4m3fn
+            )
+            attn_bmm_output = bmm_fp8(
+                attn_output_val,
+                self.w_vc,
+                attn_output_scale,
+                self.w_scale,
+                torch.bfloat16,
+            )
+        else:
+            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
+        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
+        output, _ = self.o_proj(attn_output)
+
+        return output
+
 
 def all_gather(
     input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
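Both environment variables used by the new ROCm decode path appear in this diff: SGLANG_ROCM_FUSED_DECODE_MLA gates the dispatch in forward(), and SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION controls whether the grouped decode kernel applies RoPE itself. A minimal sketch of opting in before the engine is created; only the variable names are taken from the code above, the launch sequence is illustrative:

import os

# Route decode batches on HIP through forward_absorb_fused_mla_rope (checked in forward() above).
os.environ["SGLANG_ROCM_FUSED_DECODE_MLA"] = "1"

# Keep RoPE fused inside decode_attention_fwd_grouped_rope; defaults to "1".
# Set it to "0" to apply self.rotary_emb separately before the kernel call.
os.environ["SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION"] = "1"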
@@ -651,16 +849,14 @@ def all_gather(
     if world_size == 1:
         return input_tensor
 
-    all_lens = forward_batch.global_num_tokens
-    max_len = max(forward_batch.global_num_tokens)
+    all_lens = forward_batch.global_num_tokens_cpu
+    max_len = max(forward_batch.global_num_tokens_cpu)
 
     padded_tensor = torch.nn.functional.pad(
         input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
     )
 
-    torch.distributed.all_gather_into_tensor(
-        forward_batch.gathered_buffer, padded_tensor, group=group
-    )
+    group.all_gather_into_tensor(forward_batch.gathered_buffer, padded_tensor)
 
     gathered_tensors = torch.concat(
         [
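The rewritten all_gather keeps the same data-parallel gather pattern but reads token counts from global_num_tokens_cpu and calls all_gather_into_tensor on the group object itself (the get_tp_group() handle stored by the decoder layer below) rather than going through torch.distributed directly. An illustrative single-process sketch of the underlying pad / gather / trim logic, with no process group involved:

import torch

# Token counts per data-parallel rank, as if read from forward_batch.global_num_tokens_cpu.
all_lens = [3, 5, 2]
max_len, hidden = max(all_lens), 8
local = [torch.randn(n, hidden) for n in all_lens]

# Pad every rank to max_len (the torch.nn.functional.pad call above), emulate the
# gathered buffer by stacking, then keep only each rank's valid prefix.
gathered_buffer = torch.stack(
    [torch.nn.functional.pad(t, (0, 0, 0, max_len - t.shape[0])) for t in local]
)
gathered = torch.concat([gathered_buffer[i, :n] for i, n in enumerate(all_lens)])
assert gathered.shape == (sum(all_lens), hidden)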
@@ -683,6 +879,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         is_nextn: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -696,7 +893,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         if self.enable_dp_attention:
             self.tp_rank = get_tensor_model_parallel_rank()
             self.tp_size = get_tensor_model_parallel_world_size()
-            self.tp_group = get_tp_group().device_group
+            self.tp_group = get_tp_group()
         if not global_server_args_dict["disable_mla"]:
             self.self_attn = DeepseekV2AttentionMLA(
                 config=config,
@@ -715,6 +912,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 quant_config=quant_config,
                 layer_id=layer_id,
                 use_dp=self.enable_dp_attention,
+                prefix=add_prefix("self_attn", prefix),
             )
         else:
             self.self_attn = DeepseekV2Attention(
@@ -733,19 +931,25 @@ class DeepseekV2DecoderLayer(nn.Module):
                 max_position_embeddings=max_position_embeddings,
                 quant_config=quant_config,
                 layer_id=layer_id,
+                prefix=add_prefix("self_attn", prefix),
             )
         if is_nextn or (
             config.n_routed_experts is not None
             and layer_id >= config.first_k_dense_replace
             and layer_id % config.moe_layer_freq == 0
         ):
-            self.mlp = DeepseekV2MoE(config=config, quant_config=quant_config)
+            self.mlp = DeepseekV2MoE(
+                config=config,
+                quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
+            )
         else:
             self.mlp = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=config.intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
             )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -797,6 +1001,7 @@ class DeepseekV2Model(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.padding_id = config.pad_token_id
@@ -813,6 +1018,7 @@ class DeepseekV2Model(nn.Module):
                     config,
                     layer_id,
                     quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{layer_id}", prefix),
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
@@ -843,21 +1049,28 @@ class DeepseekV2ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = DeepseekV2Model(config, quant_config)
+        self.model = DeepseekV2Model(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
         if global_server_args_dict["enable_dp_attention"]:
             self.lm_head = ReplicatedLinear(
                 config.hidden_size,
                 config.vocab_size,
                 bias=False,
+                prefix=add_prefix("lm_head", prefix),
             )
             self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
         else:
             self.lm_head = ParallelLMHead(
-                config.vocab_size, config.hidden_size, quant_config=quant_config
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
             )
             self.logits_processor = LogitsProcessor(config)
 
@@ -989,6 +1202,18 @@ class DeepseekV2ForCausalLM(nn.Module):
                             weight, weight_scale, weight_block_size
                         )
                         self_attn.w_scale = scale
+                if (
+                    hasattr(self.quant_config, "weight_block_size")
+                    and w.dtype == torch.int8
+                ):
+                    weight_block_size = self.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        weight = w
+                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
+                        w = int8_block_dequant(
+                            weight, weight_scale, weight_block_size
+                        ).to(torch.bfloat16)
                 w_kc, w_vc = w.unflatten(
                     0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                 ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
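int8_block_dequant reverses block-wise int8 quantization before the kv_b_proj weight is split into w_kc / w_vc: the int8 weight carries one floating-point scale per block (matching the block_shape=[128, 128] kernel configs added elsewhere in this release), and dequantization multiplies each block by its scale. A minimal reference sketch of that operation; the real implementation in sglang/srt/layers/quantization/int8_utils.py may differ in layout and is likely kernel-backed:

import torch

def block_dequant_reference(weight_q, scale, block_shape):
    # weight_q: int8 [N, K]; scale: float [ceil(N / bn), ceil(K / bk)], one value per block.
    bn, bk = block_shape
    n, k = weight_q.shape
    scale_full = (
        scale.repeat_interleave(bn, dim=0)[:n].repeat_interleave(bk, dim=1)[:, :k]
    )
    return weight_q.to(torch.float32) * scale_full

w_q = torch.randint(-128, 127, (256, 512), dtype=torch.int8)
s = torch.rand(256 // 128, 512 // 128)  # per-block scales
w = block_dequant_reference(w_q, s, (128, 128)).to(torch.bfloat16)
assert w.shape == w_q.shape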
@@ -1002,6 +1227,17 @@ class DeepseekV2ForCausalLM(nn.Module):
                 if is_hip_:
                     self_attn.w_scale *= 2.0
 
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
 
 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
     pass