sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +23 -2
  3. sglang/bench_serving.py +6 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/configs/model_config.py +37 -5
  12. sglang/srt/constrained/base_grammar_backend.py +26 -5
  13. sglang/srt/constrained/llguidance_backend.py +1 -0
  14. sglang/srt/constrained/outlines_backend.py +1 -0
  15. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  16. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  17. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  18. sglang/srt/constrained/xgrammar_backend.py +27 -4
  19. sglang/srt/custom_op.py +0 -62
  20. sglang/srt/disaggregation/base/__init__.py +8 -0
  21. sglang/srt/disaggregation/base/conn.py +113 -0
  22. sglang/srt/disaggregation/decode.py +80 -11
  23. sglang/srt/disaggregation/mini_lb.py +58 -123
  24. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  25. sglang/srt/disaggregation/mooncake/conn.py +585 -0
  26. sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
  27. sglang/srt/disaggregation/prefill.py +82 -22
  28. sglang/srt/disaggregation/utils.py +46 -0
  29. sglang/srt/entrypoints/EngineBase.py +53 -0
  30. sglang/srt/entrypoints/engine.py +36 -8
  31. sglang/srt/entrypoints/http_server.py +37 -8
  32. sglang/srt/entrypoints/http_server_engine.py +142 -0
  33. sglang/srt/entrypoints/verl_engine.py +42 -13
  34. sglang/srt/hf_transformers_utils.py +4 -0
  35. sglang/srt/layers/activation.py +6 -8
  36. sglang/srt/layers/attention/flashattention_backend.py +430 -257
  37. sglang/srt/layers/attention/flashinfer_backend.py +18 -9
  38. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  39. sglang/srt/layers/attention/triton_backend.py +6 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  41. sglang/srt/layers/attention/vision.py +1 -1
  42. sglang/srt/layers/dp_attention.py +2 -4
  43. sglang/srt/layers/elementwise.py +15 -2
  44. sglang/srt/layers/layernorm.py +1 -1
  45. sglang/srt/layers/linear.py +18 -3
  46. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  48. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
  56. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  57. sglang/srt/layers/moe/router.py +7 -1
  58. sglang/srt/layers/moe/topk.py +63 -45
  59. sglang/srt/layers/parameter.py +0 -2
  60. sglang/srt/layers/quantization/__init__.py +13 -5
  61. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  62. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
  64. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  65. sglang/srt/layers/quantization/fp8.py +131 -136
  66. sglang/srt/layers/quantization/fp8_kernel.py +328 -46
  67. sglang/srt/layers/quantization/fp8_utils.py +206 -253
  68. sglang/srt/layers/quantization/kv_cache.py +43 -52
  69. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  70. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  71. sglang/srt/layers/quantization/utils.py +5 -11
  72. sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
  73. sglang/srt/layers/quantization/w8a8_int8.py +8 -7
  74. sglang/srt/layers/radix_attention.py +28 -1
  75. sglang/srt/layers/rotary_embedding.py +15 -3
  76. sglang/srt/layers/sampler.py +5 -10
  77. sglang/srt/lora/backend/base_backend.py +18 -2
  78. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  79. sglang/srt/lora/backend/triton_backend.py +1 -1
  80. sglang/srt/lora/layers.py +1 -1
  81. sglang/srt/lora/lora.py +1 -1
  82. sglang/srt/lora/lora_manager.py +1 -1
  83. sglang/srt/managers/detokenizer_manager.py +0 -1
  84. sglang/srt/managers/io_struct.py +255 -97
  85. sglang/srt/managers/mm_utils.py +7 -5
  86. sglang/srt/managers/multimodal_processor.py +0 -2
  87. sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
  88. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  89. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  90. sglang/srt/managers/schedule_batch.py +64 -25
  91. sglang/srt/managers/scheduler.py +80 -82
  92. sglang/srt/managers/tokenizer_manager.py +18 -3
  93. sglang/srt/managers/tp_worker.py +1 -0
  94. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  95. sglang/srt/mem_cache/memory_pool.py +21 -3
  96. sglang/srt/metrics/collector.py +9 -0
  97. sglang/srt/model_executor/cuda_graph_runner.py +9 -6
  98. sglang/srt/model_executor/forward_batch_info.py +234 -15
  99. sglang/srt/model_executor/model_runner.py +67 -35
  100. sglang/srt/model_loader/loader.py +31 -4
  101. sglang/srt/model_loader/weight_utils.py +4 -2
  102. sglang/srt/models/baichuan.py +2 -0
  103. sglang/srt/models/bert.py +398 -0
  104. sglang/srt/models/chatglm.py +1 -0
  105. sglang/srt/models/commandr.py +1 -0
  106. sglang/srt/models/dbrx.py +1 -0
  107. sglang/srt/models/deepseek.py +2 -1
  108. sglang/srt/models/deepseek_nextn.py +74 -70
  109. sglang/srt/models/deepseek_v2.py +494 -366
  110. sglang/srt/models/exaone.py +1 -0
  111. sglang/srt/models/gemma.py +1 -0
  112. sglang/srt/models/gemma2.py +1 -0
  113. sglang/srt/models/gemma3_causal.py +1 -0
  114. sglang/srt/models/gpt2.py +1 -0
  115. sglang/srt/models/gpt_bigcode.py +1 -0
  116. sglang/srt/models/granite.py +1 -0
  117. sglang/srt/models/grok.py +1 -0
  118. sglang/srt/models/internlm2.py +1 -0
  119. sglang/srt/models/llama.py +6 -5
  120. sglang/srt/models/llama4.py +101 -34
  121. sglang/srt/models/minicpm.py +1 -0
  122. sglang/srt/models/minicpm3.py +30 -200
  123. sglang/srt/models/mixtral.py +1 -0
  124. sglang/srt/models/mixtral_quant.py +1 -0
  125. sglang/srt/models/mllama.py +51 -8
  126. sglang/srt/models/mllama4.py +102 -29
  127. sglang/srt/models/olmo.py +1 -0
  128. sglang/srt/models/olmo2.py +1 -0
  129. sglang/srt/models/olmoe.py +1 -0
  130. sglang/srt/models/phi3_small.py +1 -0
  131. sglang/srt/models/qwen.py +1 -0
  132. sglang/srt/models/qwen2.py +5 -1
  133. sglang/srt/models/qwen2_5_vl.py +35 -70
  134. sglang/srt/models/qwen2_moe.py +15 -13
  135. sglang/srt/models/qwen2_vl.py +27 -25
  136. sglang/srt/models/qwen3.py +335 -0
  137. sglang/srt/models/qwen3_moe.py +423 -0
  138. sglang/srt/models/stablelm.py +1 -0
  139. sglang/srt/models/xverse.py +1 -0
  140. sglang/srt/models/xverse_moe.py +1 -0
  141. sglang/srt/openai_api/adapter.py +4 -1
  142. sglang/srt/patch_torch.py +11 -0
  143. sglang/srt/reasoning_parser.py +0 -1
  144. sglang/srt/sampling/sampling_batch_info.py +2 -3
  145. sglang/srt/server_args.py +55 -19
  146. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  147. sglang/srt/speculative/eagle_utils.py +1 -11
  148. sglang/srt/speculative/eagle_worker.py +10 -9
  149. sglang/srt/utils.py +136 -10
  150. sglang/test/attention/test_flashattn_backend.py +259 -221
  151. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  152. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  153. sglang/test/runners.py +5 -1
  154. sglang/test/test_block_fp8.py +224 -0
  155. sglang/test/test_custom_ops.py +1 -1
  156. sglang/test/test_utils.py +19 -8
  157. sglang/version.py +1 -1
  158. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
  159. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
  160. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  161. sglang/lang/__init__.py +0 -0
  162. sglang/srt/disaggregation/conn.py +0 -81
  163. sglang/srt/lora/backend/__init__.py +0 -25
  164. sglang/srt/server.py +0 -18
  165. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  166. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
@@ -18,6 +18,8 @@
 
 import logging
 import os
+from dataclasses import dataclass
+from enum import Enum, IntEnum, auto
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -27,6 +29,7 @@ from tqdm import tqdm
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     parallel_state,
     tensor_model_parallel_all_reduce,
@@ -54,9 +57,14 @@ from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.fp8_kernel import (
+    _enable_jit_deepgemm_bmm,
+    per_tensor_quant_mla_deep_gemm_masked_fp8,
+    per_tensor_quant_mla_fp8,
+)
 from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
-    input_to_float8,
+    channel_quant_to_tensor_quant,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.int8_utils import (
@@ -72,15 +80,16 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip
+from sglang.srt.utils import BumpAllocator, DeepEPMode, add_prefix, is_cuda, is_hip
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
 
 if _is_cuda:
-    from sgl_kernel import awq_dequantize, bmm_fp8
+    from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
+    from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
 else:
-    from vllm import _custom_ops as ops
+    from vllm._custom_ops import awq_dequantize
 
 if _is_hip:
     from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
@@ -92,6 +101,18 @@ expert_distribution_recorder = ExpertDistributionRecorder()
 logger = logging.getLogger(__name__)
 
 
+class AttnForwardMethod(IntEnum):
+    # Use multi-head attention
+    MHA = auto()
+
+    # Use absorbed multi-latent attention
+    MLA = auto()
+
+    # Use multi-head attention, but with KV cache chunked.
+    # This method can avoid OOM when prefix lengths are long.
+    MHA_CHUNKED_KV = auto()
+
+
 class DeepseekV2MLP(nn.Module):
     def __init__(
         self,
@@ -131,7 +152,7 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x):
+    def forward(self, x, forward_mode: Optional[ForwardMode] = None):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
@@ -172,13 +193,8 @@ class DeepseekV2MoE(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
-        self.n_share_experts_fusion = (
-            global_server_args_dict["n_share_experts_fusion"]
-            if global_server_args_dict["n_share_experts_fusion"] is not None
-            else 0
-        )
+        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
 
-        self.routed_scaling_factor = config.routed_scaling_factor
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
@@ -210,6 +226,7 @@ class DeepseekV2MoE(nn.Module):
             num_expert_group=config.n_group,
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
             prefix=add_prefix("experts", prefix),
             **(
                 dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
@@ -278,10 +295,7 @@ class DeepseekV2MoE(nn.Module):
         return self.forward_deepep(hidden_states, forward_mode)
 
     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
-            shared_output = self.shared_experts(hidden_states)
-        else:
-            shared_output = None
+        shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         final_hidden_states = (
@@ -311,8 +325,7 @@ class DeepseekV2MoE(nn.Module):
     ):
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
-        if self.n_shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states)
+        shared_output = self._forward_shared_experts(hidden_states)
         topk_weights, topk_idx = select_experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
@@ -322,8 +335,10 @@ class DeepseekV2MoE(nn.Module):
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
             correction_bias=self.correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
         )
         if self.ep_size > 1:
+            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
            (
                 hidden_states,
                 topk_idx,
@@ -336,19 +351,15 @@ class DeepseekV2MoE(nn.Module):
                 hidden_states,
                 topk_idx,
                 topk_weights,
-                self.num_experts,
                 forward_mode=forward_mode,
             )
-        final_hidden_states = (
-            self.experts(
-                hidden_states=hidden_states,
-                reorder_topk_ids=reorder_topk_ids,
-                seg_indptr=seg_indptr,
-                masked_m=masked_m,
-                expected_m=expected_m,
-                forward_mode=forward_mode,
-            )
-            * self.routed_scaling_factor
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            reorder_topk_ids=reorder_topk_ids,
+            seg_indptr=seg_indptr,
+            masked_m=masked_m,
+            expected_m=expected_m,
+            forward_mode=forward_mode,
         )
         if self.ep_size > 1:
             final_hidden_states = self.deepep_dispatcher.combine(
@@ -357,11 +368,19 @@ class DeepseekV2MoE(nn.Module):
                 topk_weights,
                 forward_mode,
             )
+        final_hidden_states *= self.routed_scaling_factor
+
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
 
         return final_hidden_states
 
+    def _forward_shared_experts(self, hidden_states):
+        if self.n_share_experts_fusion == 0:
+            return self.shared_experts(hidden_states)
+        else:
+            return None
+
 
 def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     import math
@@ -371,178 +390,6 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     return 0.1 * mscale * math.log(scale) + 1.0
 
 
-class DeepseekV2Attention(nn.Module):
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        hidden_size: int,
-        num_heads: int,
-        qk_nope_head_dim: int,
-        qk_rope_head_dim: int,
-        v_head_dim: int,
-        q_lora_rank: int,
-        kv_lora_rank: int,
-        rope_theta: float = 10000,
-        rope_scaling: Optional[Dict[str, Any]] = None,
-        max_position_embeddings: int = 8192,
-        quant_config: Optional[QuantizationConfig] = None,
-        layer_id=None,
-        reduce_results: bool = True,
-        prefix: str = "",
-    ) -> None:
-        super().__init__()
-        self.layer_id = layer_id
-        self.hidden_size = hidden_size
-        self.qk_nope_head_dim = qk_nope_head_dim
-        self.qk_rope_head_dim = qk_rope_head_dim
-        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
-        self.v_head_dim = v_head_dim
-        self.q_lora_rank = q_lora_rank
-        self.kv_lora_rank = kv_lora_rank
-
-        self.dp_size = get_attention_dp_size()
-        attn_tp_rank = get_attention_tp_rank()
-        attn_tp_size = get_attention_tp_size()
-
-        self.num_heads = num_heads
-        assert num_heads % attn_tp_size == 0
-        self.num_local_heads = num_heads // attn_tp_size
-        self.scaling = self.qk_head_dim**-0.5
-        self.rope_theta = rope_theta
-        self.max_position_embeddings = max_position_embeddings
-
-        if self.q_lora_rank is not None:
-            self.q_a_proj = ReplicatedLinear(
-                self.hidden_size,
-                self.q_lora_rank,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_a_proj", prefix),
-            )
-            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
-            self.q_b_proj = ColumnParallelLinear(
-                q_lora_rank,
-                self.num_heads * self.qk_head_dim,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_b_proj", prefix),
-            )
-        else:
-            self.q_proj = ColumnParallelLinear(
-                self.hidden_size,
-                self.num_heads * self.qk_head_dim,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_proj", prefix),
-                tp_rank=attn_tp_rank,
-                tp_size=attn_tp_size,
-            )
-
-        self.kv_a_proj_with_mqa = ReplicatedLinear(
-            self.hidden_size,
-            self.kv_lora_rank + self.qk_rope_head_dim,
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
-        )
-        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
-        self.kv_b_proj = ColumnParallelLinear(
-            self.kv_lora_rank,
-            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("kv_b_proj", prefix),
-        )
-        # O projection.
-        self.o_proj = RowParallelLinear(
-            self.num_heads * self.v_head_dim,
-            self.hidden_size,
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("o_proj", prefix),
-            reduce_results=reduce_results,
-            tp_rank=attn_tp_rank,
-            tp_size=attn_tp_size,
-        )
-        rope_scaling["rope_type"] = "deepseek_yarn"
-        self.rotary_emb = get_rope_wrapper(
-            qk_rope_head_dim,
-            rotary_dim=qk_rope_head_dim,
-            max_position=max_position_embeddings,
-            base=rope_theta,
-            rope_scaling=rope_scaling,
-            is_neox_style=False,
-            device=global_server_args_dict["device"],
-        )
-
-        if rope_scaling:
-            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
-            scaling_factor = rope_scaling["factor"]
-            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
-            self.scaling = self.scaling * mscale * mscale
-
-        # TODO, support head_size 192
-        self.attn = RadixAttention(
-            self.num_local_heads,
-            256,
-            self.scaling,
-            num_kv_heads=self.num_local_heads,
-            layer_id=layer_id,
-            prefix=add_prefix("attn", prefix),
-        )
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
-        if hidden_states.shape[0] == 0:
-            assert (
-                not self.o_proj.reduce_results
-            ), "short-circuiting allreduce will lead to hangs"
-            return hidden_states
-
-        if self.q_lora_rank is not None:
-            q = self.q_a_proj(hidden_states)[0]
-            q = self.q_a_layernorm(q)
-            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
-        else:
-            q = self.q_proj(hidden_states)[0].view(
-                -1, self.num_local_heads, self.qk_head_dim
-            )
-        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
-        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        latent_cache = latent_cache.unsqueeze(1)
-        kv_a = self.kv_a_layernorm(kv_a.contiguous())
-        kv = self.kv_b_proj(kv_a)[0]
-        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-        k_pe = latent_cache[:, :, self.kv_lora_rank :]
-        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
-        q[..., self.qk_nope_head_dim :] = q_pe
-        k = torch.empty_like(q)
-        k[..., : self.qk_nope_head_dim] = k_nope
-        k[..., self.qk_nope_head_dim :] = k_pe
-        q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0).view(
-            -1, self.num_local_heads * 256
-        )
-        k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim], value=0).view(
-            -1, self.num_local_heads * 256
-        )
-        v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value=0).view(
-            -1, self.num_local_heads * 256
-        )
-        attn_output = self.attn(q, k, v, forward_batch)
-        attn_output = attn_output.view(-1, self.num_local_heads, 256)[
-            ..., : self.v_head_dim
-        ].reshape(-1, self.num_local_heads * self.v_head_dim)
-        output, _ = self.o_proj(attn_output)
-        return output
-
-
 class DeepseekV2AttentionMLA(nn.Module):
 
     def __init__(
@@ -669,6 +516,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             num_kv_heads=1,
             layer_id=layer_id,
             v_head_dim=self.kv_lora_rank,
+            quant_config=quant_config,
             prefix=add_prefix("attn_mqa", prefix),
         )
 
@@ -679,6 +527,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             num_kv_heads=self.num_local_heads,
             layer_id=layer_id,
             v_head_dim=self.v_head_dim,
+            quant_config=quant_config,
             prefix=add_prefix("attn_mha", prefix),
         )
 
@@ -686,39 +535,68 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.w_vc = None
         self.w_scale = None
 
+        self.w_scale_k = None
+        self.w_scale_v = None
+        self.use_deep_gemm_bmm = False
+
         self.flashinfer_mla_disable_ragged = global_server_args_dict[
             "flashinfer_mla_disable_ragged"
         ]
+        self.disable_chunked_prefix_cache = global_server_args_dict[
+            "disable_chunked_prefix_cache"
+        ]
         self.attention_backend = global_server_args_dict["attention_backend"]
         self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
 
-    def no_absorb(self, forward_batch: ForwardBatch) -> bool:
+        # TODO: Design a finer way to determine the threshold
+        self.chunked_prefix_cache_threshold = 8192
+
+    def dispatch_attn_forward_method(
+        self, forward_batch: ForwardBatch
+    ) -> AttnForwardMethod:
         if self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
-            return (
+            if (
                 not self.flashinfer_mla_disable_ragged
                 and forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
                 and sum(forward_batch.extend_prefix_lens_cpu) == 0
-            )
+            ):
+                return AttnForwardMethod.MHA
+            else:
+                return AttnForwardMethod.MLA
         elif self.attention_backend == "fa3":
-            # Flash Attention: Keep absorbing for all extend/decode
-            return False
+            # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not self.disable_chunked_prefix_cache
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+                and sum(forward_batch.extend_prefix_lens_cpu)
+                >= self.chunked_prefix_cache_threshold
+            ):
+                return AttnForwardMethod.MHA_CHUNKED_KV
+            else:
+                return AttnForwardMethod.MLA
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
-            return (
+            if (
                 forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
                 and sum(forward_batch.extend_prefix_lens_cpu) == 0
-            )
+            ):
+                return AttnForwardMethod.MHA
+            else:
+                return AttnForwardMethod.MLA
 
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
    ) -> torch.Tensor:
         if hidden_states.shape[0] == 0:
             assert (
@@ -726,8 +604,14 @@ class DeepseekV2AttentionMLA(nn.Module):
             ), "short-circuiting allreduce will lead to hangs"
             return hidden_states
 
-        if self.no_absorb(forward_batch):
+        attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
+
+        if attn_forward_method == AttnForwardMethod.MHA:
             return self.forward_normal(positions, hidden_states, forward_batch)
+        elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
+            return self.forward_normal_chunked_kv(
+                positions, hidden_states, forward_batch
+            )
         else:
             if _is_hip:
                 if (
@@ -738,9 +622,13 @@ class DeepseekV2AttentionMLA(nn.Module):
                         positions, hidden_states, forward_batch
                     )
                 else:
-                    return self.forward_absorb(positions, hidden_states, forward_batch)
+                    return self.forward_absorb(
+                        positions, hidden_states, forward_batch, zero_allocator
+                    )
             else:
-                return self.forward_absorb(positions, hidden_states, forward_batch)
+                return self.forward_absorb(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
 
     def forward_normal(
         self,
@@ -789,6 +677,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         q_len = hidden_states.shape[0]
         q_input = hidden_states.new_empty(
@@ -804,15 +693,33 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
 
-        if self.w_kc.dtype == torch.float8_e4m3fnuz:
+        if self.use_deep_gemm_bmm:
+            q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
+                per_tensor_quant_mla_deep_gemm_masked_fp8(
+                    q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
+                )
+            )
+            q_nope_out = q_nope.new_empty(
+                (self.num_local_heads, aligned_m, self.kv_lora_rank)
+            )
+            m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+                (q_nope_val, q_nope_scale),
+                (self.w_kc, self.w_scale_k),
+                q_nope_out,
+                masked_m,
+                expected_m,
+            )
+            q_nope_out = q_nope_out[:, :expected_m, :]
+        elif self.w_kc.dtype == torch.float8_e4m3fnuz:
             # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
             q_nope_out = torch.bmm(
                 q_nope.to(torch.bfloat16).transpose(0, 1),
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
-            q_nope_val, q_nope_scale = input_to_float8(
-                q_nope.transpose(0, 1), torch.float8_e4m3fn
+            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+                q_nope.transpose(0, 1),
+                zero_allocator.allocate(1),
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -835,15 +742,33 @@ class DeepseekV2AttentionMLA(nn.Module):
         attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
 
-        if self.w_vc.dtype == torch.float8_e4m3fnuz:
+        if self.use_deep_gemm_bmm:
+            attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
+                per_tensor_quant_mla_deep_gemm_masked_fp8(
+                    attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
+                )
+            )
+            attn_bmm_output = attn_output.new_empty(
+                (self.num_local_heads, aligned_m, self.v_head_dim)
+            )
+            m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+                (attn_output_val, attn_output_scale),
+                (self.w_vc, self.w_scale_v),
+                attn_bmm_output,
+                masked_m,
+                expected_m,
+            )
+            attn_bmm_output = attn_bmm_output[:, :expected_m, :]
+        elif self.w_vc.dtype == torch.float8_e4m3fnuz:
             # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
             attn_bmm_output = torch.bmm(
                 attn_output.to(torch.bfloat16).transpose(0, 1),
                 self.w_vc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_vc.dtype == torch.float8_e4m3fn:
-            attn_output_val, attn_output_scale = input_to_float8(
-                attn_output.transpose(0, 1), torch.float8_e4m3fn
+            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
+                attn_output.transpose(0, 1),
+                zero_allocator.allocate(1),
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
@@ -864,6 +789,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         enable_rope_fusion = (
             os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
@@ -889,8 +815,10 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
-            q_nope_val, q_nope_scale = input_to_float8(
-                q_nope.transpose(0, 1), torch.float8_e4m3fn
+            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+                q_nope.transpose(0, 1),
+                zero_allocator.allocate(1),
+                dtype=torch.float8_e4m3fn,
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -985,8 +913,10 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_vc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_vc.dtype == torch.float8_e4m3fn:
-            attn_output_val, attn_output_scale = input_to_float8(
-                attn_output.transpose(0, 1), torch.float8_e4m3fn
+            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
+                attn_output.transpose(0, 1),
+                zero_allocator.allocate(1),
+                dtype=torch.float8_e4m3fn,
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
@@ -1002,6 +932,140 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         return output
 
+    def _chunked_prefix_attn_mha(
+        self,
+        q: torch.Tensor,
+        accum_output: torch.Tensor,
+        accum_lse: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+
+        assert forward_batch.num_prefix_chunks is not None
+        for i in range(forward_batch.num_prefix_chunks):
+            forward_batch.set_prefix_chunk_idx(i)
+
+            # Fetch latent cache from memory pool with precomputed chunked kv indices
+            latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
+                self.attn_mha.layer_id
+            )
+            latent_cache = latent_cache_buf[
+                forward_batch.prefix_chunk_kv_indices[i]
+            ].contiguous()
+
+            kv_a_normed, k_pe = latent_cache.split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            )
+            kv_a_normed = kv_a_normed.squeeze(1).contiguous()
+            kv = self.kv_b_proj(kv_a_normed)[0]
+            kv = kv.view(
+                -1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim
+            )
+            v = kv[..., self.qk_nope_head_dim :]
+            k_nope = kv[..., : self.qk_nope_head_dim]
+
+            k = torch.empty(
+                (
+                    k_nope.shape[0],
+                    self.num_local_heads,
+                    self.qk_nope_head_dim + self.qk_rope_head_dim,
+                ),
+                dtype=v.dtype,
+                device=v.device,
+            )
+            k[..., : self.qk_nope_head_dim] = k_nope
+            k[..., self.qk_nope_head_dim :] = k_pe
+
+            output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
+            lse = torch.transpose(lse, 0, 1).contiguous()
+            tmp_output = torch.empty_like(accum_output)
+            tmp_lse = torch.empty_like(accum_lse)
+            merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
+            accum_output, accum_lse = tmp_output, tmp_lse
+
+        return accum_output
+
+    def forward_normal_chunked_kv(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        # In normal mha, the k and v tensors will become overly large when the prefix length is long.
+        # To avoid this, we split the kv cache into chunks and process them one after another.
+        # Since mha is compute friendly, the for loop induced here will not introduce significant overhead.
+        # The top comments in https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py
+        # will be helpful for understanding the purpose of this function.
+
+        # First do normal mha forward to get output for extended part
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+        latent_cache = latent_cache.unsqueeze(1)
+        kv_a = self.kv_a_layernorm(kv_a.contiguous())
+        kv = self.kv_b_proj(kv_a)[0]
+        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope = kv[..., : self.qk_nope_head_dim]
+        v = kv[..., self.qk_nope_head_dim :]
+        k_pe = latent_cache[:, :, self.kv_lora_rank :]
+
+        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+        q[..., self.qk_nope_head_dim :] = q_pe
+        k = torch.empty_like(q)
+        k[..., : self.qk_nope_head_dim] = k_nope
+        k[..., self.qk_nope_head_dim :] = k_pe
+
+        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+        latent_cache[:, :, self.kv_lora_rank :] = k_pe
+
+        # Save latent cache
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+        )
+
+        # Do mha for extended part without prefix
+        forward_batch.set_attn_attend_prefix_cache(False)
+        attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
+        lse = torch.transpose(lse, 0, 1).contiguous()
+
+        # Do mha attention with chunked prefix cache if there are any sequence with prefix
+        if any(forward_batch.extend_prefix_lens_cpu):
+            # Only initialize the info once
+            if forward_batch.num_prefix_chunks is None:
+                forward_batch.prepare_chunked_prefix_cache_info(q.device)
+
+            forward_batch.set_attn_attend_prefix_cache(True)
+            attn_output = self._chunked_prefix_attn_mha(
+                q=q,
+                accum_output=attn_output,
+                accum_lse=lse,
+                forward_batch=forward_batch,
+            )
+
+        attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class _FFNInputMode(Enum):
+    # The MLP sublayer requires 1/tp_size tokens as input
+    SCATTERED = auto()
+    # The MLP sublayer requires all tokens as input
+    FULL = auto()
+
+
+@dataclass
+class _DecoderLayerInfo:
+    is_sparse: bool
+    ffn_input_mode: _FFNInputMode
+
 
 class DeepseekV2DecoderLayer(nn.Module):
 
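Note on the chunked-prefix path added in the hunk above: each prefix chunk is attended to separately, and the per-chunk output is folded into the running result together with its log-sum-exp (LSE) via the fused merge_state_v2 kernel imported from sgl_kernel. A minimal PyTorch sketch of that merge (simplified shapes, hypothetical helper name, not the sgl_kernel implementation) is:

    import torch

    def merge_attention_states(o_a, lse_a, o_b, lse_b):
        # o_*: [num_heads, num_tokens, head_dim] partial attention outputs
        # lse_*: [num_heads, num_tokens] log-sum-exp of each softmax denominator
        lse = torch.logaddexp(lse_a, lse_b)          # combined normalizer
        w_a = torch.exp(lse_a - lse).unsqueeze(-1)   # weight of the first partial result
        w_b = torch.exp(lse_b - lse).unsqueeze(-1)   # weight of the second partial result
        return w_a * o_a + w_b * o_b, lse

Accumulating chunk results this way is algebraically equivalent to attending over the whole prefix at once, which is why looping over KV chunks does not change the attention output.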
@@ -1013,14 +1077,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         is_nextn: bool = False,
         prefix: str = "",
     ) -> None:
-
-        def is_sparse_layer(l: int):
-            return (
-                config.n_routed_experts is not None
-                and l >= config.first_k_dense_replace
-                and l % config.moe_layer_freq == 0
-            )
-
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
@@ -1031,68 +1087,54 @@ class DeepseekV2DecoderLayer(nn.Module):
         self.dp_size = get_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
+        self.self_attn = DeepseekV2AttentionMLA(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            qk_nope_head_dim=config.qk_nope_head_dim,
+            qk_rope_head_dim=config.qk_rope_head_dim,
+            v_head_dim=config.v_head_dim,
+            q_lora_rank=(
+                config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+            ),
+            kv_lora_rank=config.kv_lora_rank,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            layer_id=layer_id,
+            reduce_results=False,
+            prefix=add_prefix("self_attn", prefix),
+        )
 
-        if not global_server_args_dict["disable_mla"]:
-            self.self_attn = DeepseekV2AttentionMLA(
-                config=config,
-                hidden_size=self.hidden_size,
-                num_heads=config.num_attention_heads,
-                qk_nope_head_dim=config.qk_nope_head_dim,
-                qk_rope_head_dim=config.qk_rope_head_dim,
-                v_head_dim=config.v_head_dim,
-                q_lora_rank=(
-                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
-                ),
-                kv_lora_rank=config.kv_lora_rank,
-                rope_theta=rope_theta,
-                rope_scaling=rope_scaling,
-                max_position_embeddings=max_position_embeddings,
-                quant_config=quant_config,
-                layer_id=layer_id,
-                reduce_results=False,
-                prefix=add_prefix("self_attn", prefix),
-            )
-        else:
-            self.self_attn = DeepseekV2Attention(
-                config=config,
-                hidden_size=self.hidden_size,
-                num_heads=config.num_attention_heads,
-                qk_nope_head_dim=config.qk_nope_head_dim,
-                qk_rope_head_dim=config.qk_rope_head_dim,
-                v_head_dim=config.v_head_dim,
-                q_lora_rank=(
-                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
-                ),
-                kv_lora_rank=config.kv_lora_rank,
-                rope_theta=rope_theta,
-                rope_scaling=rope_scaling,
-                max_position_embeddings=max_position_embeddings,
-                quant_config=quant_config,
-                layer_id=layer_id,
-                reduce_results=False,
-                prefix=add_prefix("self_attn", prefix),
-            )
+        self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
+        previous_layer_info = self._compute_info(
+            config, layer_id=layer_id - 1, is_nextn=False
+        )
 
-        if is_nextn or is_sparse_layer(layer_id):
+        if self.info.is_sparse:
             self.mlp = DeepseekV2MoE(
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
             )
-            self.is_sparse = True
         else:
+            if self._enable_moe_dense_fully_dp():
+                mlp_tp_rank, mlp_tp_size = 0, 1
+            else:
+                mlp_tp_rank, mlp_tp_size = None, None
             self.mlp = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=config.intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
+                tp_rank=mlp_tp_rank,
+                tp_size=mlp_tp_size,
             )
-            self.is_sparse = False
 
         self.input_is_scattered = (
-            is_sparse_layer(layer_id - 1)
-            and global_server_args_dict["enable_deepep_moe"]
+            previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
         )
         self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
 
@@ -1101,28 +1143,51 @@ class DeepseekV2DecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )
 
+    @staticmethod
+    def _enable_moe_dense_fully_dp():
+        return global_server_args_dict["moe_dense_tp_size"] == 1
+
+    @staticmethod
+    def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool):
+        is_sparse = is_nextn or (
+            config.n_routed_experts is not None
+            and layer_id >= config.first_k_dense_replace
+            and layer_id % config.moe_layer_freq == 0
+        )
+        ffn_input_mode = (
+            _FFNInputMode.SCATTERED
+            if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
+            or (DeepseekV2DecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
+            else _FFNInputMode.FULL
+        )
+        return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
+
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
-        if global_server_args_dict["enable_deepep_moe"] and self.is_sparse:
-            return self.forward_deepep(
-                positions, hidden_states, forward_batch, residual
+        if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
+            return self.forward_ffn_with_scattered_input(
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
-        else:
-            return self.forward_normal(
-                positions, hidden_states, forward_batch, residual
+        elif self.info.ffn_input_mode == _FFNInputMode.FULL:
+            return self.forward_ffn_with_full_input(
+                positions, hidden_states, forward_batch, residual, zero_allocator
            )
+        else:
+            raise NotImplementedError
 
-    def forward_normal(
+    def forward_ffn_with_full_input(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
 
         if hidden_states.shape[0] == 0:
@@ -1143,6 +1208,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             positions=positions,
             hidden_states=hidden_states,
             forward_batch=forward_batch,
+            zero_allocator=zero_allocator,
         )
 
         # Gather
@@ -1184,12 +1250,13 @@ class DeepseekV2DecoderLayer(nn.Module):
 
         return hidden_states, residual
 
-    def forward_deepep(
+    def forward_ffn_with_scattered_input(
        self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
 
         if hidden_states.shape[0] == 0:
@@ -1215,6 +1282,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             positions=positions,
             hidden_states=hidden_states,
             forward_batch=forward_batch,
+            zero_allocator=zero_allocator,
         )
 
         if self.attn_tp_size != 1:
@@ -1240,7 +1308,13 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states, residual = self.post_attention_layernorm(
                 hidden_states, residual
             )
-            hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
+
+        if not (
+            self._enable_moe_dense_fully_dp()
+            and (not self.info.is_sparse)
+            and hidden_states.shape[0] == 0
+        ):
+            hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
 
         if self.is_last_layer and self.attn_tp_size != 1:
             hidden_states += residual
@@ -1296,6 +1370,14 @@ class DeepseekV2Model(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+        zero_allocator = BumpAllocator(
+            # TODO for two-batch-overlap, we need a larger buffer size
+            buffer_size=len(self.layers) * 2,
+            dtype=torch.float32,
+            device=(
+                input_embeds.device if input_embeds is not None else input_ids.device
+            ),
+        )
 
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
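Note on the zero_allocator introduced in the hunk above: DeepseekV2Model now builds one BumpAllocator per forward pass (two float32 slots per decoder layer) and threads it down to the attention layers, where calls such as per_tensor_quant_mla_fp8(q_nope.transpose(0, 1), zero_allocator.allocate(1)) draw a pre-zeroed slot from it instead of allocating a fresh tensor on every call, presumably to avoid repeated small device allocations. A rough sketch of the idea (illustrative only; the real BumpAllocator lives in sglang/srt/utils.py and may differ):

    import torch

    class BumpAllocatorSketch:
        # Pre-allocates one zeroed buffer and hands out consecutive slices of it.
        def __init__(self, buffer_size: int, dtype: torch.dtype, device: torch.device):
            self._buffer = torch.zeros(buffer_size, dtype=dtype, device=device)
            self._pointer = 0

        def allocate(self, size: int) -> torch.Tensor:
            assert self._pointer + size <= self._buffer.numel(), "allocator exhausted"
            view = self._buffer[self._pointer : self._pointer + size]
            self._pointer += size
            return view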
@@ -1307,7 +1389,7 @@ class DeepseekV2Model(nn.Module):
             expert_distribution_recorder.set_current_layer(i)
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states, forward_batch, residual
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
         if not forward_batch.forward_mode.is_idle():
             if residual is None:
@@ -1330,24 +1412,33 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
         self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
-        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
-        if (
-            global_server_args_dict.get("disable_shared_experts_fusion", False)
-            or self.config.architectures[0] != "DeepseekV3ForCausalLM"
-            or self.config.n_routed_experts != 256
-            or self.config.routed_scaling_factor != 2.5
-        ):
-            self.n_share_experts_fusion = None
-            global_server_args_dict["n_share_experts_fusion"] = None
-            logger.info(
-                "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
-            )
-        elif self.n_share_experts_fusion is None:
-            global_server_args_dict["n_share_experts_fusion"] = self.tp_size
-            self.n_share_experts_fusion = self.tp_size
-            logger.info(
-                f"Shared experts fusion optimization is default enabled in DeepSeek V3/R1, and n_share_experts_fusion is set to {self.tp_size}. You can tune it by setting --n_share_experts_fusion or disable it by setting --disable_shared_experts_fusion."
-            )
+        if self.n_share_experts_fusion > 0:
+            # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+            if (
+                self.config.architectures[0] != "DeepseekV3ForCausalLM"
+                or self.config.n_routed_experts != 256
+            ):
+                self.n_share_experts_fusion = 0
+                global_server_args_dict["n_share_experts_fusion"] = 0
+                logger.info(
+                    "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
+                )
+            else:
+                assert (
+                    self.n_share_experts_fusion == self.tp_size
+                ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performace."
+        elif self.n_share_experts_fusion == 0:
+            if (
+                torch.cuda.get_device_capability("cuda") >= (9, 0)
+                and self.config.architectures[0] == "DeepseekV3ForCausalLM"
+                and self.config.n_routed_experts == 256
+                and (not global_server_args_dict["enable_deepep_moe"])
+            ):
+                self.n_share_experts_fusion = self.tp_size
+                global_server_args_dict["n_share_experts_fusion"] = self.tp_size
+                logger.info(
+                    "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled."
+                )
 
         self.model = DeepseekV2Model(
             config, quant_config, prefix=add_prefix("model", prefix)
@@ -1382,35 +1473,38 @@ class DeepseekV2ForCausalLM(nn.Module):
     def post_load_weights(self):
 
         # Perform post-processing after loading weights
-
-        if not global_server_args_dict["disable_mla"]:
-            for layer_id in range(self.config.num_hidden_layers):
-                self_attn = self.model.layers[layer_id].self_attn
-                if hasattr(self_attn.kv_b_proj, "qweight"):
-                    # AWQ compatible
-                    if _is_cuda:
-                        w = awq_dequantize(
-                            self_attn.kv_b_proj.qweight,
-                            self_attn.kv_b_proj.scales,
-                            self_attn.kv_b_proj.qzeros,
-                        ).T
-                    else:
-                        w = ops.awq_dequantize(
-                            self_attn.kv_b_proj.qweight,
-                            self_attn.kv_b_proj.scales,
-                            self_attn.kv_b_proj.qzeros,
-                            0,
-                            0,
-                            0,
-                        ).T
+        for layer_id in range(self.config.num_hidden_layers):
+            self_attn = self.model.layers[layer_id].self_attn
+            if hasattr(self_attn.kv_b_proj, "qweight"):
+                # AWQ compatible
+                if _is_cuda:
+                    w = awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                    ).T
                 else:
-                    w = self_attn.kv_b_proj.weight
-                # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-                # This may affect the accuracy of fp8 model.
-                if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
-                    torch.float8_e4m3fn,
-                    torch.float8_e4m3fnuz,
-                ):
+                    w = awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                        0,
+                        0,
+                        0,
+                    ).T
+            else:
+                w = self_attn.kv_b_proj.weight
+            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+            # This may affect the accuracy of fp8 model.
+            # Fix deepseek v3 blockwise bmm by using deep_gemm
+            use_deep_gemm_bmm = False
+            model_dtype = torch.get_default_dtype()
+
+            if w.dtype in (
+                torch.float8_e4m3fn,
+                torch.float8_e4m3fnuz,
+            ):
+                if hasattr(self.quant_config, "weight_block_size"):
                     weight_block_size = self.quant_config.weight_block_size
                     if weight_block_size is not None:
                         assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
@@ -1424,29 +1518,47 @@ class DeepseekV2ForCausalLM(nn.Module):
                         weight = w
                         weight_scale = self_attn.kv_b_proj.weight_scale_inv
 
-                        w, scale = block_quant_to_tensor_quant(
-                            weight, weight_scale, weight_block_size
-                        )
-                        self_attn.w_scale = scale
-                if w.dtype == torch.int8:
-                    if hasattr(self.quant_config, "weight_block_size"):
-                        # block-wise int8 need it
-                        weight_block_size = self.quant_config.weight_block_size
-                        if weight_block_size is not None:
-                            assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                            weight = w
-                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
-                            w = int8_block_dequant(
+                        if (
+                            _is_cuda
+                            and _enable_jit_deepgemm_bmm
+                            and weight_block_size[0] == 128
+                            and weight_block_size[1] == 128
+                            and model_dtype == torch.bfloat16
+                        ):
+                            block_scale = weight_scale
+                            use_deep_gemm_bmm = True
+                        else:
+                            w, scale = block_quant_to_tensor_quant(
                                 weight, weight_scale, weight_block_size
-                            ).to(torch.bfloat16)
-                        else:
-                            # channel-wise int8 need it
-                            w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
-                                torch.bfloat16
-                            )
-                w_kc, w_vc = w.unflatten(
-                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
-                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+                            )
+                            self_attn.w_scale = scale
+                else:
+                    weight = w
+                    weight_scale = self_attn.kv_b_proj.weight_scale
+                    w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
+                    self_attn.w_scale = scale
+
+            if w.dtype == torch.int8:
+                if hasattr(self.quant_config, "weight_block_size"):
+                    # block-wise int8 need it
+                    weight_block_size = self.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        weight = w
+                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
+                        w = int8_block_dequant(
+                            weight, weight_scale, weight_block_size
+                        ).to(torch.bfloat16)
+                else:
+                    # channel-wise int8 need it
+                    w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
+                        torch.bfloat16
+                    )
+
+            w_kc, w_vc = w.unflatten(
+                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+            if not use_deep_gemm_bmm:
                 self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                 self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
                 if (
@@ -1456,6 +1568,17 @@ class DeepseekV2ForCausalLM(nn.Module):
                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
                    if _is_hip:
                        self_attn.w_scale *= 2.0
+            else:
+                num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
+                num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
+                ws_kc, ws_vc = block_scale.unflatten(
+                    0, (-1, (num_tiles_k + num_tiles_n))
+                ).split([num_tiles_k, num_tiles_n], dim=1)
+                self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
+                self_attn.w_scale_v = ws_vc.contiguous()
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
+                self_attn.w_vc = w_vc.contiguous()
+                self_attn.use_deep_gemm_bmm = True
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -1463,17 +1586,27 @@ class DeepseekV2ForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if self.n_share_experts_fusion is not None and self.n_share_experts_fusion > 0:
+        if self.n_share_experts_fusion > 0:
             weights_list = list(weights)
             weights_dict = dict(weights_list)
-            suffix_list = [
-                "down_proj.weight",
-                "down_proj.weight_scale_inv",
-                "gate_proj.weight",
-                "gate_proj.weight_scale_inv",
-                "up_proj.weight",
-                "up_proj.weight_scale_inv",
-            ]
+            if self.quant_config.get_name() == "w8a8_int8":
+                suffix_list = [
+                    "down_proj.weight",
+                    "down_proj.weight_scale",
+                    "gate_proj.weight",
+                    "gate_proj.weight_scale",
+                    "up_proj.weight",
+                    "up_proj.weight_scale",
+                ]
+            else:
+                suffix_list = [
+                    "down_proj.weight",
+                    "down_proj.weight_scale_inv",
+                    "gate_proj.weight",
+                    "gate_proj.weight_scale_inv",
+                    "up_proj.weight",
+                    "up_proj.weight_scale_inv",
+                ]
             names_to_remove = []
             for moe_layer in tqdm(
                 range(
@@ -1512,12 +1645,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts
-            + (
-                self.n_share_experts_fusion
-                if self.n_share_experts_fusion is not None
-                else 0
-            ),
+            num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
         )
 
         params_dict = dict(self.named_parameters())