sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +220 -378
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +143 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +681 -259
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +224 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +44 -18
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +94 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +208 -28
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +136 -52
  181. sglang/srt/speculative/build_eagle_tree.py +2 -8
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  183. sglang/srt/speculative/eagle_utils.py +92 -58
  184. sglang/srt/speculative/eagle_worker.py +186 -94
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
@@ -16,6 +16,7 @@
  # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
  """Inference-only DeepseekV2 model."""

+ import os
  from typing import Any, Dict, Iterable, Optional, Tuple

  import torch
@@ -31,6 +32,9 @@ from sglang.srt.distributed import (
  tensor_model_parallel_all_reduce,
  )
  from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
+ decode_attention_fwd_grouped_rope,
+ )
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
  ColumnParallelLinear,
@@ -47,6 +51,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
  input_to_float8,
  normalize_e4m3fn_to_e4m3fnuz,
  )
+ from sglang.srt.layers.quantization.int8_utils import (
+ block_dequant as int8_block_dequant,
+ )
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
  from sglang.srt.layers.vocab_parallel_embedding import (
@@ -56,7 +63,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
- from sglang.srt.utils import is_cuda_available, is_hip
+ from sglang.srt.utils import add_prefix, is_cuda_available, is_hip

  is_hip_ = is_hip()

@@ -72,10 +79,15 @@ class DeepseekV2MLP(nn.Module):
  hidden_act: str,
  quant_config: Optional[QuantizationConfig] = None,
  reduce_results: bool = True,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.gate_up_proj = MergedColumnParallelLinear(
- hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+ hidden_size,
+ [intermediate_size] * 2,
+ bias=False,
+ quant_config=quant_config,
+ prefix=add_prefix("gate_up_proj", prefix),
  )
  self.down_proj = RowParallelLinear(
  intermediate_size,
@@ -83,6 +95,7 @@ class DeepseekV2MLP(nn.Module):
  bias=False,
  quant_config=quant_config,
  reduce_results=reduce_results,
+ prefix=add_prefix("down_proj", prefix),
  )
  if hidden_act != "silu":
  raise ValueError(
@@ -99,7 +112,11 @@ class DeepseekV2MLP(nn.Module):


  class MoEGate(nn.Module):
- def __init__(self, config):
+ def __init__(
+ self,
+ config,
+ prefix: str = "",
+ ):
  super().__init__()
  self.weight = nn.Parameter(
  torch.empty((config.n_routed_experts, config.hidden_size))
@@ -122,6 +139,7 @@ class DeepseekV2MoE(nn.Module):
  self,
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ):
  super().__init__()
  self.tp_size = get_tensor_model_parallel_world_size()
@@ -140,7 +158,7 @@ class DeepseekV2MoE(nn.Module):
  "Only silu is supported for now."
  )

- self.gate = MoEGate(config=config)
+ self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))

  MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
  self.experts = MoEImpl(
@@ -154,6 +172,7 @@ class DeepseekV2MoE(nn.Module):
  num_expert_group=config.n_group,
  topk_group=config.topk_group,
  correction_bias=self.gate.e_score_correction_bias,
+ prefix=add_prefix("experts", prefix),
  )

  if config.n_shared_experts is not None:
@@ -164,6 +183,7 @@ class DeepseekV2MoE(nn.Module):
  hidden_act=config.hidden_act,
  quant_config=quant_config,
  reduce_results=False,
+ prefix=add_prefix("shared_experts", prefix),
  )

  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -210,6 +230,7 @@ class DeepseekV2Attention(nn.Module):
  max_position_embeddings: int = 8192,
  quant_config: Optional[QuantizationConfig] = None,
  layer_id=None,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.layer_id = layer_id
@@ -234,6 +255,7 @@ class DeepseekV2Attention(nn.Module):
  self.q_lora_rank,
  bias=False,
  quant_config=quant_config,
+ prefix=add_prefix("q_a_proj", prefix),
  )
  self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
  self.q_b_proj = ColumnParallelLinear(
@@ -241,6 +263,7 @@ class DeepseekV2Attention(nn.Module):
  self.num_heads * self.qk_head_dim,
  bias=False,
  quant_config=quant_config,
+ prefix=add_prefix("q_b_proj", prefix),
  )
  else:
  self.q_proj = ColumnParallelLinear(
@@ -248,6 +271,7 @@ class DeepseekV2Attention(nn.Module):
  self.num_heads * self.qk_head_dim,
  bias=False,
  quant_config=quant_config,
+ prefix=add_prefix("q_proj", prefix),
  )

  self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -255,8 +279,7 @@ class DeepseekV2Attention(nn.Module):
  self.kv_lora_rank + self.qk_rope_head_dim,
  bias=False,
  quant_config=quant_config,
- # FIXME: quick fix for skip quantization
- prefix=f"self_attn.kv_a_proj_with_mqa",
+ prefix=add_prefix("kv_a_proj_with_mqa", prefix),
  )
  self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
  self.kv_b_proj = ColumnParallelLinear(
@@ -264,6 +287,7 @@ class DeepseekV2Attention(nn.Module):
  self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
  bias=False,
  quant_config=quant_config,
+ prefix=add_prefix("kv_b_proj", prefix),
  )
  # O projection.
  self.o_proj = RowParallelLinear(
@@ -271,6 +295,7 @@ class DeepseekV2Attention(nn.Module):
  self.hidden_size,
  bias=False,
  quant_config=quant_config,
+ prefix=add_prefix("o_proj", prefix),
  )
  rope_scaling["rope_type"] = "deepseek_yarn"
  self.rotary_emb = get_rope_wrapper(
@@ -296,6 +321,7 @@ class DeepseekV2Attention(nn.Module):
  self.scaling,
  num_kv_heads=self.num_local_heads,
  layer_id=layer_id,
+ prefix=add_prefix("attn", prefix),
  )

  def forward(
@@ -361,6 +387,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  quant_config: Optional[QuantizationConfig] = None,
  layer_id=None,
  use_dp=False,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.layer_id = layer_id
@@ -387,6 +414,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  self.q_lora_rank,
  bias=False,
  quant_config=quant_config,
+ prefix=add_prefix("q_a_proj", prefix),
  )
  self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
  self.q_b_proj = ReplicatedLinear(
@@ -394,6 +422,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  self.num_heads * self.qk_head_dim,
  bias=False,
  quant_config=quant_config,
+ prefix=add_prefix("q_b_proj", prefix),
  )
  else:
  self.q_proj = ReplicatedLinear(
@@ -401,12 +430,14 @@ class DeepseekV2AttentionMLA(nn.Module):
  self.num_heads * self.qk_head_dim,
  bias=False,
  quant_config=quant_config,
+ prefix=add_prefix("q_proj", prefix),
  )
  self.kv_b_proj = ReplicatedLinear(
  self.kv_lora_rank,
  self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
  bias=False,
  quant_config=quant_config,
+ prefix=add_prefix("kv_b_proj", prefix),
  )
  # O projection.
  self.o_proj = ReplicatedLinear(
@@ -414,6 +445,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  self.hidden_size,
  bias=False,
  quant_config=quant_config,
+ prefix=add_prefix("o_proj", prefix),
  )
  else:
  # For tensor parallel attention
@@ -423,6 +455,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  self.q_lora_rank,
  bias=False,
  quant_config=quant_config,
+ prefix=add_prefix("q_a_proj", prefix),
  )
  self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
  self.q_b_proj = ColumnParallelLinear(
@@ -430,6 +463,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  self.num_heads * self.qk_head_dim,
  bias=False,
  quant_config=quant_config,
+ prefix=add_prefix("q_b_proj", prefix),
  )
  else:
  self.q_proj = ColumnParallelLinear(
@@ -437,12 +471,14 @@ class DeepseekV2AttentionMLA(nn.Module):
  self.num_heads * self.qk_head_dim,
  bias=False,
  quant_config=quant_config,
+ prefix=add_prefix("q_proj", prefix),
  )
  self.kv_b_proj = ColumnParallelLinear(
  self.kv_lora_rank,
  self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
  bias=False,
  quant_config=quant_config,
+ prefix=add_prefix("kv_b_proj", prefix),
  )
  # O projection.
  self.o_proj = RowParallelLinear(
@@ -450,6 +486,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  self.hidden_size,
  bias=False,
  quant_config=quant_config,
+ prefix=add_prefix("o_proj", prefix),
  )

  self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -457,8 +494,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  self.kv_lora_rank + self.qk_rope_head_dim,
  bias=False,
  quant_config=quant_config,
- # FIXME: quick fix for skip quantization
- prefix=f"self_attn.kv_a_proj_with_mqa",
+ prefix=add_prefix("kv_a_proj_with_mqa", prefix),
  )
  self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)

@@ -489,6 +525,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  num_kv_heads=1,
  layer_id=layer_id,
  v_head_dim=self.kv_lora_rank,
+ prefix=add_prefix("attn_mqa", prefix),
  )

  self.attn_mha = RadixAttention(
@@ -498,6 +535,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  num_kv_heads=self.num_local_heads,
  layer_id=layer_id,
  v_head_dim=self.v_head_dim,
+ prefix=add_prefix("attn_mha", prefix),
  )

  self.w_kc = None
@@ -510,23 +548,37 @@ class DeepseekV2AttentionMLA(nn.Module):
  hidden_states: torch.Tensor,
  forward_batch: ForwardBatch,
  ) -> torch.Tensor:
- if global_server_args_dict["enable_flashinfer_mla"]:
- if global_server_args_dict["disable_radix_cache"]:
- if forward_batch.forward_mode.is_extend():
- return self.forward_normal(positions, hidden_states, forward_batch)
- else:
- return self.forward_absorb(positions, hidden_states, forward_batch)
+
+ def no_absorb() -> bool:
+ if global_server_args_dict["enable_flashinfer_mla"]:
+ # Flashinfer MLA: Do not absorb when enabling ragged prefill
+ return (
+ not global_server_args_dict["flashinfer_mla_disable_ragged"]
+ and forward_batch.forward_mode.is_extend()
+ and forward_batch.extend_prefix_lens.sum() == 0
+ )
  else:
- return self.forward_absorb(positions, hidden_states, forward_batch)
+ # Triton: Use normal computation for prefill and use weight absorption for extend/decode
+ return (
+ forward_batch.forward_mode.is_extend()
+ and not forward_batch.forward_mode.is_target_verify()
+ and not forward_batch.forward_mode.is_draft_extend()
+ and forward_batch.extend_prefix_lens.sum() == 0
+ )
+
+ if no_absorb():
+ return self.forward_normal(positions, hidden_states, forward_batch)
  else:
- # Triton: Use normal computation for prefill and use weight absorption for extend/decode
- if (
- forward_batch.forward_mode.is_extend()
- and not forward_batch.forward_mode.is_target_verify()
- and not forward_batch.forward_mode.is_draft_extend()
- and forward_batch.extend_prefix_lens.sum() == 0
- ):
- return self.forward_normal(positions, hidden_states, forward_batch)
+ if is_hip_:
+ if (
+ os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
+ and forward_batch.forward_mode.is_decode()
+ ):
+ return self.forward_absorb_fused_mla_rope(
+ positions, hidden_states, forward_batch
+ )
+ else:
+ return self.forward_absorb(positions, hidden_states, forward_batch)
  else:
  return self.forward_absorb(positions, hidden_states, forward_batch)

@@ -647,6 +699,149 @@ class DeepseekV2AttentionMLA(nn.Module):

  return output

+ def forward_absorb_fused_mla_rope(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ forward_batch: ForwardBatch,
+ ) -> torch.Tensor:
+ enable_rope_fusion = (
+ os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
+ )
+ q_len = hidden_states.shape[0]
+ q_input = hidden_states.new_empty(
+ q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
+ )
+ if self.q_lora_rank is not None:
+ q = self.q_a_proj(hidden_states)[0]
+ q = self.q_a_layernorm(q)
+ q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+ else:
+ q = self.q_proj(hidden_states)[0].view(
+ -1, self.num_local_heads, self.qk_head_dim
+ )
+ q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+ if self.w_kc.dtype == torch.float8_e4m3fnuz:
+ # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+ q_nope_out = torch.bmm(
+ q_nope.to(torch.bfloat16).transpose(0, 1),
+ self.w_kc.to(torch.bfloat16) * self.w_scale,
+ )
+ elif self.w_kc.dtype == torch.float8_e4m3fn:
+ q_nope_val, q_nope_scale = input_to_float8(
+ q_nope.transpose(0, 1), torch.float8_e4m3fn
+ )
+ q_nope_out = bmm_fp8(
+ q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
+ )
+ else:
+ q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
+ q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
+
+ latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+ v_input = latent_cache[..., : self.kv_lora_rank]
+ v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
+ k_input = latent_cache.unsqueeze(1)
+ k_input[..., : self.kv_lora_rank] = v_input
+
+ if not enable_rope_fusion:
+ k_pe = k_input[..., self.kv_lora_rank :]
+ q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+ q_input[..., self.kv_lora_rank :] = q_pe
+ k_input[..., self.kv_lora_rank :] = k_pe
+ k_pe_output = None
+ else:
+ k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :])
+
+ q_input[..., self.kv_lora_rank :] = q_pe
+
+ # attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
+ # Use Fused ROPE with use_rope=OFF.
+ attn_output = torch.empty(
+ (q_len, self.num_local_heads, self.kv_lora_rank),
+ dtype=q.dtype,
+ device=q.device,
+ )
+ attn_logits, _, kv_indptr, kv_indices, _, _, _ = (
+ forward_batch.attn_backend.forward_metadata
+ )
+ cos_sin_cache = self.rotary_emb.cos_sin_cache
+ num_kv_split = forward_batch.attn_backend.num_kv_splits
+ sm_scale = self.attn_mqa.scaling
+ if attn_logits is None:
+ attn_logits = torch.empty(
+ (
+ forward_batch.batch_size,
+ self.num_local_heads,
+ num_kv_split,
+ self.kv_lora_rank + 1,
+ ),
+ dtype=torch.float32,
+ device=q.device,
+ )
+
+ # save current latent cache.
+ forward_batch.token_to_kv_pool.set_kv_buffer(
+ self.attn_mqa, forward_batch.out_cache_loc, k_input, None
+ )
+ key_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
+ self.attn_mqa.layer_id
+ )
+ val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]
+
+ decode_attention_fwd_grouped_rope(
+ q_input,
+ key_cache_buf,
+ val_cache_buf,
+ attn_output,
+ kv_indptr,
+ kv_indices,
+ k_pe_output,
+ self.kv_lora_rank,
+ self.rotary_emb.rotary_dim,
+ cos_sin_cache,
+ positions,
+ attn_logits,
+ num_kv_split,
+ sm_scale,
+ logit_cap=self.attn_mqa.logit_cap,
+ use_rope=enable_rope_fusion,
+ is_neox_style=self.rotary_emb.is_neox_style,
+ )
+
+ if enable_rope_fusion:
+ k_input[..., self.kv_lora_rank :] = k_pe_output
+ forward_batch.token_to_kv_pool.set_kv_buffer(
+ self.attn_mqa, forward_batch.out_cache_loc, k_input, None
+ )
+
+ attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+
+ if self.w_vc.dtype == torch.float8_e4m3fnuz:
+ # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+ attn_bmm_output = torch.bmm(
+ attn_output.to(torch.bfloat16).transpose(0, 1),
+ self.w_vc.to(torch.bfloat16) * self.w_scale,
+ )
+ elif self.w_vc.dtype == torch.float8_e4m3fn:
+ attn_output_val, attn_output_scale = input_to_float8(
+ attn_output.transpose(0, 1), torch.float8_e4m3fn
+ )
+ attn_bmm_output = bmm_fp8(
+ attn_output_val,
+ self.w_vc,
+ attn_output_scale,
+ self.w_scale,
+ torch.bfloat16,
+ )
+ else:
+ attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
+ attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
+ output, _ = self.o_proj(attn_output)
+
+ return output
+

  def all_gather(
  input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
@@ -654,16 +849,14 @@ def all_gather(
  if world_size == 1:
  return input_tensor

- all_lens = forward_batch.global_num_tokens
- max_len = max(forward_batch.global_num_tokens)
+ all_lens = forward_batch.global_num_tokens_cpu
+ max_len = max(forward_batch.global_num_tokens_cpu)

  padded_tensor = torch.nn.functional.pad(
  input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
  )

- torch.distributed.all_gather_into_tensor(
- forward_batch.gathered_buffer, padded_tensor, group=group
- )
+ group.all_gather_into_tensor(forward_batch.gathered_buffer, padded_tensor)

  gathered_tensors = torch.concat(
  [
@@ -686,6 +879,7 @@ class DeepseekV2DecoderLayer(nn.Module):
  layer_id: int,
  quant_config: Optional[QuantizationConfig] = None,
  is_nextn: bool = False,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.hidden_size = config.hidden_size
@@ -699,7 +893,7 @@ class DeepseekV2DecoderLayer(nn.Module):
  if self.enable_dp_attention:
  self.tp_rank = get_tensor_model_parallel_rank()
  self.tp_size = get_tensor_model_parallel_world_size()
- self.tp_group = get_tp_group().device_group
+ self.tp_group = get_tp_group()
  if not global_server_args_dict["disable_mla"]:
  self.self_attn = DeepseekV2AttentionMLA(
  config=config,
@@ -718,6 +912,7 @@ class DeepseekV2DecoderLayer(nn.Module):
  quant_config=quant_config,
  layer_id=layer_id,
  use_dp=self.enable_dp_attention,
+ prefix=add_prefix("self_attn", prefix),
  )
  else:
  self.self_attn = DeepseekV2Attention(
@@ -736,19 +931,25 @@ class DeepseekV2DecoderLayer(nn.Module):
  max_position_embeddings=max_position_embeddings,
  quant_config=quant_config,
  layer_id=layer_id,
+ prefix=add_prefix("self_attn", prefix),
  )
  if is_nextn or (
  config.n_routed_experts is not None
  and layer_id >= config.first_k_dense_replace
  and layer_id % config.moe_layer_freq == 0
  ):
- self.mlp = DeepseekV2MoE(config=config, quant_config=quant_config)
+ self.mlp = DeepseekV2MoE(
+ config=config,
+ quant_config=quant_config,
+ prefix=add_prefix("mlp", prefix),
+ )
  else:
  self.mlp = DeepseekV2MLP(
  hidden_size=config.hidden_size,
  intermediate_size=config.intermediate_size,
  hidden_act=config.hidden_act,
  quant_config=quant_config,
+ prefix=add_prefix("mlp", prefix),
  )
  self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  self.post_attention_layernorm = RMSNorm(
@@ -800,6 +1001,7 @@ class DeepseekV2Model(nn.Module):
  self,
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.padding_id = config.pad_token_id
@@ -816,6 +1018,7 @@ class DeepseekV2Model(nn.Module):
  config,
  layer_id,
  quant_config=quant_config,
+ prefix=add_prefix(f"layers.{layer_id}", prefix),
  )
  for layer_id in range(config.num_hidden_layers)
  ]
@@ -846,21 +1049,28 @@ class DeepseekV2ForCausalLM(nn.Module):
  self,
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.config = config
  self.quant_config = quant_config
- self.model = DeepseekV2Model(config, quant_config)
+ self.model = DeepseekV2Model(
+ config, quant_config, prefix=add_prefix("model", prefix)
+ )
  if global_server_args_dict["enable_dp_attention"]:
  self.lm_head = ReplicatedLinear(
  config.hidden_size,
  config.vocab_size,
  bias=False,
+ prefix=add_prefix("lm_head", prefix),
  )
  self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
  else:
  self.lm_head = ParallelLMHead(
- config.vocab_size, config.hidden_size, quant_config=quant_config
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=add_prefix("lm_head", prefix),
  )
  self.logits_processor = LogitsProcessor(config)

@@ -992,6 +1202,18 @@ class DeepseekV2ForCausalLM(nn.Module):
  weight, weight_scale, weight_block_size
  )
  self_attn.w_scale = scale
+ if (
+ hasattr(self.quant_config, "weight_block_size")
+ and w.dtype == torch.int8
+ ):
+ weight_block_size = self.quant_config.weight_block_size
+ if weight_block_size is not None:
+ assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+ weight = w
+ weight_scale = self_attn.kv_b_proj.weight_scale_inv
+ w = int8_block_dequant(
+ weight, weight_scale, weight_block_size
+ ).to(torch.bfloat16)
  w_kc, w_vc = w.unflatten(
  0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
  ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
@@ -1005,6 +1227,17 @@ class DeepseekV2ForCausalLM(nn.Module):
  if is_hip_:
  self_attn.w_scale *= 2.0

+ def get_embed_and_head(self):
+ return self.model.embed_tokens.weight, self.lm_head.weight
+
+ def set_embed_and_head(self, embed, head):
+ del self.model.embed_tokens.weight
+ del self.lm_head.weight
+ self.model.embed_tokens.weight = embed
+ self.lm_head.weight = head
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+

  class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
  pass
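
Note on the prefix threading in the diff above: nearly every constructor in this file now takes a `prefix` argument and passes `prefix=add_prefix(...)` down to its submodules, with `add_prefix` newly imported from `sglang.srt.utils`. The effect is that each projection layer knows its fully qualified weight name, which quantization and weight-loading logic can key on (the removed FIXME previously hard-coded one such name for exactly that purpose). The sketch below is an illustrative assumption about how such a helper composes dotted names, not a copy of the sglang implementation.

# Illustrative sketch only: assumes add_prefix(name, prefix) joins a child
# module name onto its parent's dotted prefix, yielding names like
# "model.layers.0.self_attn.kv_a_proj_with_mqa".
def add_prefix(name: str, prefix: str) -> str:
    # If there is no parent prefix, the child name stands alone.
    return name if not prefix else f"{prefix}.{name}"

# Composing prefixes the way the new constructors do:
layer_prefix = add_prefix("layers.0", add_prefix("model", ""))
attn_prefix = add_prefix("self_attn", layer_prefix)
assert add_prefix("kv_a_proj_with_mqa", attn_prefix) == "model.layers.0.self_attn.kv_a_proj_with_mqa"

As a usage note, the dispatch logic added to DeepseekV2AttentionMLA.forward indicates that the new ROCm fused MLA decode path is opt-in via the SGLANG_ROCM_FUSED_DECODE_MLA=1 environment variable, with rope fusion toggled by SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION (default "1").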