sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff shows the changes between two publicly released package versions as they appear in their public registry, and is provided for informational purposes only.
Files changed (166)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +23 -2
  3. sglang/bench_serving.py +6 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/configs/model_config.py +37 -5
  12. sglang/srt/constrained/base_grammar_backend.py +26 -5
  13. sglang/srt/constrained/llguidance_backend.py +1 -0
  14. sglang/srt/constrained/outlines_backend.py +1 -0
  15. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  16. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  17. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  18. sglang/srt/constrained/xgrammar_backend.py +27 -4
  19. sglang/srt/custom_op.py +0 -62
  20. sglang/srt/disaggregation/base/__init__.py +8 -0
  21. sglang/srt/disaggregation/base/conn.py +113 -0
  22. sglang/srt/disaggregation/decode.py +80 -11
  23. sglang/srt/disaggregation/mini_lb.py +58 -123
  24. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  25. sglang/srt/disaggregation/mooncake/conn.py +585 -0
  26. sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
  27. sglang/srt/disaggregation/prefill.py +82 -22
  28. sglang/srt/disaggregation/utils.py +46 -0
  29. sglang/srt/entrypoints/EngineBase.py +53 -0
  30. sglang/srt/entrypoints/engine.py +36 -8
  31. sglang/srt/entrypoints/http_server.py +37 -8
  32. sglang/srt/entrypoints/http_server_engine.py +142 -0
  33. sglang/srt/entrypoints/verl_engine.py +42 -13
  34. sglang/srt/hf_transformers_utils.py +4 -0
  35. sglang/srt/layers/activation.py +6 -8
  36. sglang/srt/layers/attention/flashattention_backend.py +430 -257
  37. sglang/srt/layers/attention/flashinfer_backend.py +18 -9
  38. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  39. sglang/srt/layers/attention/triton_backend.py +6 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  41. sglang/srt/layers/attention/vision.py +1 -1
  42. sglang/srt/layers/dp_attention.py +2 -4
  43. sglang/srt/layers/elementwise.py +15 -2
  44. sglang/srt/layers/layernorm.py +1 -1
  45. sglang/srt/layers/linear.py +18 -3
  46. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  48. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
  56. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  57. sglang/srt/layers/moe/router.py +7 -1
  58. sglang/srt/layers/moe/topk.py +63 -45
  59. sglang/srt/layers/parameter.py +0 -2
  60. sglang/srt/layers/quantization/__init__.py +13 -5
  61. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  62. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
  64. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  65. sglang/srt/layers/quantization/fp8.py +131 -136
  66. sglang/srt/layers/quantization/fp8_kernel.py +328 -46
  67. sglang/srt/layers/quantization/fp8_utils.py +206 -253
  68. sglang/srt/layers/quantization/kv_cache.py +43 -52
  69. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  70. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  71. sglang/srt/layers/quantization/utils.py +5 -11
  72. sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
  73. sglang/srt/layers/quantization/w8a8_int8.py +8 -7
  74. sglang/srt/layers/radix_attention.py +28 -1
  75. sglang/srt/layers/rotary_embedding.py +15 -3
  76. sglang/srt/layers/sampler.py +5 -10
  77. sglang/srt/lora/backend/base_backend.py +18 -2
  78. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  79. sglang/srt/lora/backend/triton_backend.py +1 -1
  80. sglang/srt/lora/layers.py +1 -1
  81. sglang/srt/lora/lora.py +1 -1
  82. sglang/srt/lora/lora_manager.py +1 -1
  83. sglang/srt/managers/detokenizer_manager.py +0 -1
  84. sglang/srt/managers/io_struct.py +255 -97
  85. sglang/srt/managers/mm_utils.py +7 -5
  86. sglang/srt/managers/multimodal_processor.py +0 -2
  87. sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
  88. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  89. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  90. sglang/srt/managers/schedule_batch.py +64 -25
  91. sglang/srt/managers/scheduler.py +80 -82
  92. sglang/srt/managers/tokenizer_manager.py +18 -3
  93. sglang/srt/managers/tp_worker.py +1 -0
  94. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  95. sglang/srt/mem_cache/memory_pool.py +21 -3
  96. sglang/srt/metrics/collector.py +9 -0
  97. sglang/srt/model_executor/cuda_graph_runner.py +9 -6
  98. sglang/srt/model_executor/forward_batch_info.py +234 -15
  99. sglang/srt/model_executor/model_runner.py +67 -35
  100. sglang/srt/model_loader/loader.py +31 -4
  101. sglang/srt/model_loader/weight_utils.py +4 -2
  102. sglang/srt/models/baichuan.py +2 -0
  103. sglang/srt/models/bert.py +398 -0
  104. sglang/srt/models/chatglm.py +1 -0
  105. sglang/srt/models/commandr.py +1 -0
  106. sglang/srt/models/dbrx.py +1 -0
  107. sglang/srt/models/deepseek.py +2 -1
  108. sglang/srt/models/deepseek_nextn.py +74 -70
  109. sglang/srt/models/deepseek_v2.py +494 -366
  110. sglang/srt/models/exaone.py +1 -0
  111. sglang/srt/models/gemma.py +1 -0
  112. sglang/srt/models/gemma2.py +1 -0
  113. sglang/srt/models/gemma3_causal.py +1 -0
  114. sglang/srt/models/gpt2.py +1 -0
  115. sglang/srt/models/gpt_bigcode.py +1 -0
  116. sglang/srt/models/granite.py +1 -0
  117. sglang/srt/models/grok.py +1 -0
  118. sglang/srt/models/internlm2.py +1 -0
  119. sglang/srt/models/llama.py +6 -5
  120. sglang/srt/models/llama4.py +101 -34
  121. sglang/srt/models/minicpm.py +1 -0
  122. sglang/srt/models/minicpm3.py +30 -200
  123. sglang/srt/models/mixtral.py +1 -0
  124. sglang/srt/models/mixtral_quant.py +1 -0
  125. sglang/srt/models/mllama.py +51 -8
  126. sglang/srt/models/mllama4.py +102 -29
  127. sglang/srt/models/olmo.py +1 -0
  128. sglang/srt/models/olmo2.py +1 -0
  129. sglang/srt/models/olmoe.py +1 -0
  130. sglang/srt/models/phi3_small.py +1 -0
  131. sglang/srt/models/qwen.py +1 -0
  132. sglang/srt/models/qwen2.py +5 -1
  133. sglang/srt/models/qwen2_5_vl.py +35 -70
  134. sglang/srt/models/qwen2_moe.py +15 -13
  135. sglang/srt/models/qwen2_vl.py +27 -25
  136. sglang/srt/models/qwen3.py +335 -0
  137. sglang/srt/models/qwen3_moe.py +423 -0
  138. sglang/srt/models/stablelm.py +1 -0
  139. sglang/srt/models/xverse.py +1 -0
  140. sglang/srt/models/xverse_moe.py +1 -0
  141. sglang/srt/openai_api/adapter.py +4 -1
  142. sglang/srt/patch_torch.py +11 -0
  143. sglang/srt/reasoning_parser.py +0 -1
  144. sglang/srt/sampling/sampling_batch_info.py +2 -3
  145. sglang/srt/server_args.py +55 -19
  146. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  147. sglang/srt/speculative/eagle_utils.py +1 -11
  148. sglang/srt/speculative/eagle_worker.py +10 -9
  149. sglang/srt/utils.py +136 -10
  150. sglang/test/attention/test_flashattn_backend.py +259 -221
  151. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  152. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  153. sglang/test/runners.py +5 -1
  154. sglang/test/test_block_fp8.py +224 -0
  155. sglang/test/test_custom_ops.py +1 -1
  156. sglang/test/test_utils.py +19 -8
  157. sglang/version.py +1 -1
  158. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
  159. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
  160. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  161. sglang/lang/__init__.py +0 -0
  162. sglang/srt/disaggregation/conn.py +0 -81
  163. sglang/srt/lora/backend/__init__.py +0 -25
  164. sglang/srt/server.py +0 -18
  165. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  166. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/exaone.py CHANGED
@@ -155,6 +155,7 @@ class ExaoneAttention(nn.Module):
  self.scaling,
  num_kv_heads=self.num_kv_heads,
  layer_id=layer_id,
+ quant_config=quant_config,
  )

  def forward(

sglang/srt/models/gemma.py CHANGED
@@ -137,6 +137,7 @@ class GemmaAttention(nn.Module):
  self.scaling,
  num_kv_heads=self.num_kv_heads,
  layer_id=layer_id,
+ quant_config=quant_config,
  prefix=add_prefix("attn", prefix),
  )

sglang/srt/models/gemma2.py CHANGED
@@ -163,6 +163,7 @@ class Gemma2Attention(nn.Module):
  if use_sliding_window
  else None
  ),
+ quant_config=quant_config,
  prefix=add_prefix("attn", prefix),
  )

sglang/srt/models/gemma3_causal.py CHANGED
@@ -193,6 +193,7 @@ class Gemma3Attention(nn.Module):
  # Module must also define `get_attention_sliding_window_size` to correctly initialize
  # attention backend in `ForwardBatch`.
  sliding_window_size=self.sliding_window,
+ quant_config=quant_config,
  prefix=add_prefix("attn", prefix),
  )

sglang/srt/models/gpt2.py CHANGED
@@ -78,6 +78,7 @@ class GPT2Attention(nn.Module):
  scaling=self.scale,
  num_kv_heads=total_num_heads,
  layer_id=layer_id,
+ quant_config=quant_config,
  )

  def forward(

sglang/srt/models/gpt_bigcode.py CHANGED
@@ -87,6 +87,7 @@ class GPTBigCodeAttention(nn.Module):
  scaling=self.scale,
  num_kv_heads=self.num_kv_heads,
  layer_id=layer_id,
+ quant_config=quant_config,
  prefix=add_prefix("attn", prefix),
  )

sglang/srt/models/granite.py CHANGED
@@ -158,6 +158,7 @@ class GraniteAttention(nn.Module):
  self.scaling,
  num_kv_heads=self.num_kv_heads,
  layer_id=layer_id,
+ quant_config=quant_config,
  prefix=add_prefix("attn", prefix),
  )

sglang/srt/models/grok.py CHANGED
@@ -215,6 +215,7 @@ class Grok1Attention(nn.Module):
  num_kv_heads=self.num_kv_heads,
  layer_id=layer_id,
  logit_cap=logit_cap,
+ quant_config=quant_config,
  )

  def forward(

sglang/srt/models/internlm2.py CHANGED
@@ -145,6 +145,7 @@ class InternLM2Attention(nn.Module):
  self.scaling,
  self.num_kv_heads,
  layer_id,
+ quant_config=quant_config,
  prefix=add_prefix("attn", prefix),
  )

sglang/srt/models/llama.py CHANGED
@@ -170,6 +170,7 @@ class LlamaAttention(nn.Module):
  self.scaling,
  num_kv_heads=self.num_kv_heads,
  layer_id=layer_id,
+ quant_config=quant_config,
  prefix=add_prefix("attn", prefix),
  )
@@ -361,11 +362,11 @@ class LlamaForCausalLM(nn.Module):
  column_parallel_weights_modules = [".down_proj.", ".o_proj."]
  bitsandbytes_stacked_params_mapping = {
  # shard_name, weight_name, index
- "q_proj": ("qkv_proj", 0),
- "k_proj": ("qkv_proj", 1),
- "v_proj": ("qkv_proj", 2),
- "gate_proj": ("gate_up_proj", 0),
- "up_proj": ("gate_up_proj", 1),
+ ".q_proj": (".qkv_proj", 0),
+ ".k_proj": (".qkv_proj", 1),
+ ".v_proj": (".qkv_proj", 2),
+ ".gate_proj": (".gate_up_proj", 0),
+ ".up_proj": (".gate_up_proj", 1),
  }

  def __init__(
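
Note on the mapping change above: the leading dots anchor each shard suffix to a module boundary, so substring-based remapping cannot hit look-alike parameter names. A minimal standalone sketch of that remapping idea (illustrative only, not sglang's actual loader code; the helper and the example weight names are hypothetical):

    # Illustrative sketch: why anchoring the shard suffix with a leading "."
    # makes substring remapping safe. Not sglang's loader; names are made up.
    mapping = {
        ".q_proj": (".qkv_proj", 0),
        ".k_proj": (".qkv_proj", 1),
        ".v_proj": (".qkv_proj", 2),
        ".gate_proj": (".gate_up_proj", 0),
        ".up_proj": (".gate_up_proj", 1),
    }

    def remap(weight_name: str):
        # Map a per-shard checkpoint name onto the stacked parameter it is packed into.
        for shard_suffix, (packed_name, shard_id) in mapping.items():
            if shard_suffix in weight_name:
                return weight_name.replace(shard_suffix, packed_name), shard_id
        return weight_name, None

    # ".q_proj" only matches a real submodule; a bare "q_proj" key could also
    # match unrelated names that merely contain that substring.
    print(remap("model.layers.0.self_attn.q_proj.weight"))
    # -> ('model.layers.0.self_attn.qkv_proj.weight', 0)
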
sglang/srt/models/llama4.py CHANGED
@@ -27,6 +27,13 @@ from sglang.srt.distributed import (
  get_tensor_model_parallel_world_size,
  tensor_model_parallel_all_reduce,
  )
+ from sglang.srt.layers.dp_attention import (
+ dp_gather_partial,
+ dp_scatter,
+ get_attention_dp_size,
+ get_attention_tp_rank,
+ get_attention_tp_size,
+ )
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
  QKVParallelLinear,
@@ -38,9 +45,10 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
- from sglang.srt.utils import add_prefix, get_compiler_backend, make_layers
+ from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers

  logger = logging.getLogger(__name__)
@@ -55,7 +63,7 @@ class Llama4MoE(nn.Module):
  topk: int,
  renormalize: bool,
  ) -> Tuple[torch.Tensor, torch.Tensor]:
- router_scores_aK, router_indices_aK = torch.topk(gating_output, topk, dim=-1)
+ router_scores_aK, router_indices_aK = fast_topk(gating_output, topk, dim=-1)
  router_scores_aK = torch.sigmoid(router_scores_aK.float()).to(
  hidden_states.dtype
  )
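
For reference, the routing math in the hunk above is top-k selection followed by a per-expert sigmoid; a self-contained sketch (plain torch.topk standing in for sglang's fast_topk, toy shapes assumed):

    import torch

    def router_topk_sigmoid(gating_output: torch.Tensor, topk: int):
        # Pick the top-k expert logits per token, then squash the selected
        # logits with a sigmoid so each chosen expert gets a [0, 1] weight.
        scores, indices = torch.topk(gating_output, topk, dim=-1)
        scores = torch.sigmoid(scores.float()).to(gating_output.dtype)
        return scores, indices

    gating = torch.randn(4, 16)  # 4 tokens, 16 experts
    scores, indices = router_topk_sigmoid(gating, topk=1)
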
@@ -143,20 +151,24 @@ class Llama4Attention(nn.Module):
  self.hidden_size = hidden_size
  self.use_rope = int((layer_id + 1) % 4 != 0)
  self.use_qk_norm = config.use_qk_norm and self.use_rope
- tp_size = get_tensor_model_parallel_world_size()
+
+ self.dp_size = get_attention_dp_size()
+ attn_tp_rank = get_attention_tp_rank()
+ attn_tp_size = get_attention_tp_size()
+
  self.total_num_heads = num_heads
- assert self.total_num_heads % tp_size == 0
- self.num_heads = self.total_num_heads // tp_size
+ assert self.total_num_heads % attn_tp_size == 0
+ self.num_heads = self.total_num_heads // attn_tp_size
  self.total_num_kv_heads = num_kv_heads
- if self.total_num_kv_heads >= tp_size:
+ if self.total_num_kv_heads >= attn_tp_size:
  # Number of KV heads is greater than TP size, so we partition
  # the KV heads across multiple tensor parallel GPUs.
- assert self.total_num_kv_heads % tp_size == 0
+ assert self.total_num_kv_heads % attn_tp_size == 0
  else:
  # Number of KV heads is less than TP size, so we replicate
  # the KV heads across multiple tensor parallel GPUs.
- assert tp_size % self.total_num_kv_heads == 0
- self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+ assert attn_tp_size % self.total_num_kv_heads == 0
+ self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
  self.head_dim = config.head_dim
  self.q_size = self.num_heads * self.head_dim
  self.kv_size = self.num_kv_heads * self.head_dim
@@ -183,6 +195,8 @@ class Llama4Attention(nn.Module):
  bias=bias,
  quant_config=quant_config,
  prefix=add_prefix("qkv_proj", prefix),
+ tp_rank=attn_tp_rank,
+ tp_size=attn_tp_size,
  )

  self.o_proj = RowParallelLinear(
@@ -191,6 +205,9 @@ class Llama4Attention(nn.Module):
  bias=bias_o_proj,
  quant_config=quant_config,
  prefix=add_prefix("o_proj", prefix),
+ tp_rank=attn_tp_rank,
+ tp_size=attn_tp_size,
+ reduce_results=False,
  )
  is_neox_style = True
  is_gguf = quant_config and quant_config.get_name() == "gguf"
@@ -223,9 +240,13 @@ class Llama4Attention(nn.Module):
  def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
  floor = torch.floor((positions + 1.0) / self.floor_scale)
  attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
-
  return attn_scale.unsqueeze(-1)

+ @torch.compile(dynamic=True, backend=get_compiler_backend())
+ def _mul_attn_scale(self, positions, q):
+ attn_scale = self._get_attn_scale(positions)
+ return (q * attn_scale).to(q.dtype)
+
  def forward(
  self,
  positions: torch.Tensor,
@@ -233,27 +254,29 @@ class Llama4Attention(nn.Module):
  forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  qkv, _ = self.qkv_proj(hidden_states)
- q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+ qk, v = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)

  if self.rotary_emb is not None:
- q, k = self.rotary_emb(positions, q, k)
+ q_view, k_view = qk.split([self.q_size, self.kv_size], dim=-1)
+ q_out_unused, k_out_unused = self.rotary_emb(positions, q_view, k_view)
+ assert (q_out_unused is q_view) and (k_out_unused is k_view)
+ del q_view, k_view, q_out_unused, k_out_unused

  if self.qk_norm is not None:
- # TODO: support float
- q = q.reshape(-1, self.head_dim).contiguous().bfloat16()
- k = k.reshape(-1, self.head_dim).contiguous().bfloat16()
- q = self.qk_norm(q).to(q.dtype)
- k = self.qk_norm(k).to(k.dtype)
- q = q.reshape(-1, self.q_size)
- k = k.reshape(-1, self.kv_size)
+ # TODO there are still 2 redundant direct_copy_kernel_cuda for this `reshape` and (in attn backend) q.contiguous(), maybe we can fuse them later
+ qk = qk.reshape(-1, self.head_dim).contiguous().bfloat16()
+ qk = self.qk_norm(qk).to(torch.bfloat16)
+ qk = qk.reshape(-1, self.q_size + self.kv_size)
+
+ q, k = qk.split([self.q_size, self.kv_size], dim=-1)

  # We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
  # the inference-time temperature tuning function is customized to not affect short context
  # while working at very long context
  # https://arxiv.org/abs/2501.19399
  if self.attn_temperature_tuning and not self.use_rope:
- attn_scale = self._get_attn_scale(positions)
- q = (q * attn_scale).to(q.dtype)
+ q = self._mul_attn_scale(positions=positions, q=q)

  attn_output = self.attn(q, k, v, forward_batch)
  output, _ = self.o_proj(attn_output)
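
The `_mul_attn_scale` helper above simply compiles the existing `_get_attn_scale` formula together with the multiply. A standalone sketch of that scale, with assumed example values for `floor_scale` and `attn_scale` (the real values come from the Llama 4 text config):

    import torch

    def attn_temperature_scale(positions: torch.Tensor,
                               floor_scale: float = 8192.0,
                               attn_scale: float = 0.1) -> torch.Tensor:
        # Grows logarithmically with position: short contexts are nearly
        # unchanged, very long contexts get a larger query scale.
        floor = torch.floor((positions + 1.0) / floor_scale)
        return (torch.log(floor + 1.0) * attn_scale + 1.0).unsqueeze(-1)

    positions = torch.tensor([0, 8191, 65535, 1_000_000])
    q = torch.randn(4, 128)
    q = (q * attn_temperature_scale(positions)).to(q.dtype)
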
@@ -274,6 +297,9 @@ class Llama4DecoderLayer(nn.Module):
  rope_theta = config.rope_theta
  rope_scaling = config.rope_scaling
  max_position_embeddings = config.max_position_embeddings
+ self.dp_size = get_attention_dp_size()
+ self.attn_tp_size = get_attention_tp_size()
+ self.attn_tp_rank = get_attention_tp_rank()

  self.self_attn = Llama4Attention(
  config=config,
@@ -316,21 +342,58 @@ class Llama4DecoderLayer(nn.Module):
  forward_batch: ForwardBatch,
  residual: Optional[torch.Tensor],
  ) -> Tuple[torch.Tensor, torch.Tensor]:
- # Self Attention
- if residual is None:
+ if hidden_states.shape[0] == 0:
  residual = hidden_states
- hidden_states = self.input_layernorm(hidden_states)
  else:
- hidden_states, residual = self.input_layernorm(hidden_states, residual)
- hidden_states = self.self_attn(
- positions=positions,
- hidden_states=hidden_states,
- forward_batch=forward_batch,
- )
+ # Self Attention
+ if residual is None:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ else:
+ hidden_states, residual = self.input_layernorm(hidden_states, residual)
+ hidden_states = self.self_attn(
+ positions=positions,
+ hidden_states=hidden_states,
+ forward_batch=forward_batch,
+ )
+
+ # Gather
+ if get_tensor_model_parallel_world_size() > 1:
+ # all gather and all reduce
+ if self.dp_size != 1:
+ if self.attn_tp_rank == 0:
+ hidden_states += residual
+ hidden_states, local_hidden_states = (
+ forward_batch.gathered_buffer,
+ hidden_states,
+ )
+ dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+ dp_scatter(residual, hidden_states, forward_batch)
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ else:
+ hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+ hidden_states, residual = self.post_attention_layernorm(
+ hidden_states, residual
+ )
+ else:
+ hidden_states, residual = self.post_attention_layernorm(
+ hidden_states, residual
+ )

  # Fully Connected
- hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
  hidden_states = self.feed_forward(hidden_states)
+
+ # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
+ # Scatter
+ if self.dp_size != 1:
+ # important: forward batch.gathered_buffer is used both after scatter and after gather.
+ # be careful about this!
+ hidden_states, global_hidden_states = (
+ forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+ hidden_states,
+ )
+ dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
  return hidden_states, residual
@@ -350,13 +413,14 @@ class Llama4Model(nn.Module):
  config.hidden_size,
  quant_config=quant_config,
  prefix=add_prefix("embed_tokens", prefix),
+ enable_tp=not global_server_args_dict["enable_dp_attention"],
  )
  self.layers = make_layers(
  config.num_hidden_layers,
  lambda idx, prefix: Llama4DecoderLayer(
  config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
  ),
- prefix="model.layers",
+ prefix=add_prefix("layers", prefix),
  )

  self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -385,7 +449,8 @@ class Llama4Model(nn.Module):
  forward_batch,
  residual,
  )
- hidden_states, _ = self.norm(hidden_states, residual)
+ if not forward_batch.forward_mode.is_idle():
+ hidden_states, _ = self.norm(hidden_states, residual)

  if len(aux_hidden_states) == 0:
  return hidden_states
@@ -394,7 +459,6 @@


  class Llama4ForCausalLM(LlamaForCausalLM):
-
  packed_modules_mapping = {
  "qkv_proj": ["q_proj", "k_proj", "v_proj"],
  "gate_up_proj": ["gate_proj", "up_proj"],
@@ -408,6 +472,9 @@ class Llama4ForCausalLM(LlamaForCausalLM):
  ):
  super().__init__(config, quant_config, prefix)

+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
  def _init_model(
  self,
  config: Llama4TextConfig,
sglang/srt/models/minicpm.py CHANGED
@@ -146,6 +146,7 @@ class MiniCPMAttention(nn.Module):
  self.scaling,
  num_kv_heads=self.num_kv_heads,
  layer_id=layer_id,
+ quant_config=quant_config,
  prefix=add_prefix("attn", prefix),
  )

sglang/srt/models/minicpm3.py CHANGED
@@ -93,157 +93,6 @@ def input_to_float8(x, dtype=torch.float8_e4m3fn):
  return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()


- class MiniCPM3Attention(nn.Module):
-
- def __init__(
- self,
- config: PretrainedConfig,
- hidden_size: int,
- num_heads: int,
- qk_nope_head_dim: int,
- qk_rope_head_dim: int,
- v_head_dim: int,
- q_lora_rank: int,
- kv_lora_rank: int,
- rope_theta: float = 10000,
- rope_scaling: Optional[Dict[str, Any]] = None,
- max_position_embeddings: int = 8192,
- quant_config: Optional[QuantizationConfig] = None,
- layer_id=None,
- prefix: str = "",
- ) -> None:
- super().__init__()
- self.layer_id = layer_id
- self.hidden_size = hidden_size
- self.qk_nope_head_dim = qk_nope_head_dim
- self.qk_rope_head_dim = qk_rope_head_dim
- self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
- self.v_head_dim = v_head_dim
- self.q_lora_rank = q_lora_rank
- self.kv_lora_rank = kv_lora_rank
- self.num_heads = num_heads
- tp_size = get_tensor_model_parallel_world_size()
- assert num_heads % tp_size == 0
- self.num_local_heads = num_heads // tp_size
- self.scaling = self.qk_head_dim**-0.5
- self.rope_theta = rope_theta
- self.max_position_embeddings = max_position_embeddings
-
- if self.q_lora_rank is not None:
- self.q_a_proj = ReplicatedLinear(
- self.hidden_size,
- self.q_lora_rank,
- bias=False,
- quant_config=quant_config,
- prefix=add_prefix("q_a_proj", prefix),
- )
- self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
- self.q_b_proj = ColumnParallelLinear(
- q_lora_rank,
- self.num_heads * self.qk_head_dim,
- bias=False,
- quant_config=quant_config,
- prefix=add_prefix("q_b_proj", prefix),
- )
- else:
- self.q_proj = ColumnParallelLinear(
- self.hidden_size,
- self.num_heads * self.qk_head_dim,
- bias=False,
- quant_config=quant_config,
- prefix=add_prefix("q_proj", prefix),
- )
-
- self.kv_a_proj_with_mqa = ReplicatedLinear(
- self.hidden_size,
- self.kv_lora_rank + self.qk_rope_head_dim,
- bias=False,
- quant_config=quant_config,
- prefix=add_prefix("kv_a_proj_with_mqa", prefix),
- )
- self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
- self.kv_b_proj = ColumnParallelLinear(
- self.kv_lora_rank,
- self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
- bias=False,
- quant_config=quant_config,
- prefix=add_prefix("kv_b_proj", prefix),
- )
- # O projection.
- self.o_proj = RowParallelLinear(
- self.num_heads * self.v_head_dim,
- self.hidden_size,
- bias=False,
- quant_config=quant_config,
- prefix=add_prefix("o_proj", prefix),
- )
- self.rotary_emb = get_rope(
- qk_rope_head_dim,
- rotary_dim=qk_rope_head_dim,
- max_position=max_position_embeddings,
- base=rope_theta,
- rope_scaling=rope_scaling,
- )
-
- # TODO support head_size 96
- self.attn = RadixAttention(
- self.num_local_heads,
- 128,
- self.scaling,
- num_kv_heads=self.num_local_heads,
- layer_id=layer_id,
- prefix=add_prefix("attn", prefix),
- )
-
- def forward(
- self,
- positions: torch.Tensor,
- hidden_states: torch.Tensor,
- forward_batch: ForwardBatch,
- ) -> torch.Tensor:
- if self.q_lora_rank is not None:
- q = self.q_a_proj(hidden_states)[0]
- q = self.q_a_layernorm(q)
- q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
- else:
- q = self.q_proj(hidden_states)[0].view(
- -1, self.num_local_heads, self.qk_head_dim
- )
- _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
- latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
- kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
- latent_cache = latent_cache.unsqueeze(1)
- kv_a = self.kv_a_layernorm(kv_a.contiguous())
- kv = self.kv_b_proj(kv_a)[0]
- kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
- k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
- k_pe = latent_cache[:, :, self.kv_lora_rank :]
- original_shapes = [q_pe.shape, k_pe.shape]
- q_pe, k_pe = self.rotary_emb(
- positions, q_pe.reshape(q_pe.shape[0], -1), k_pe.reshape(k_pe.shape[0], -1)
- )
- q_pe, k_pe = q_pe.view(original_shapes[0]), k_pe.view(original_shapes[1])
- q[..., self.qk_nope_head_dim :] = q_pe
- k = torch.empty_like(q)
- k[..., : self.qk_nope_head_dim] = k_nope
- k[..., self.qk_nope_head_dim :] = k_pe
- q = torch.nn.functional.pad(q, [0, 128 - self.qk_head_dim], value=0).view(
- -1, self.num_local_heads * 128
- )
- k = torch.nn.functional.pad(k, [0, 128 - self.qk_head_dim], value=0).view(
- -1, self.num_local_heads * 128
- )
- v = torch.nn.functional.pad(v, [0, 128 - self.v_head_dim], value=0).view(
- -1, self.num_local_heads * 128
- )
- attn_output = self.attn(q, k, v, forward_batch)
- attn_output = attn_output.view(-1, self.num_local_heads, 128)[
- ..., : self.v_head_dim
- ].reshape(-1, self.num_local_heads * self.v_head_dim)
- output, _ = self.o_proj(attn_output)
- return output
-
-
  class MiniCPM3AttentionMLA(nn.Module):

  def __init__(
@@ -343,6 +192,7 @@ class MiniCPM3AttentionMLA(nn.Module):
  num_kv_heads=1,
  layer_id=layer_id,
  v_head_dim=self.kv_lora_rank,
+ quant_config=quant_config,
  prefix=add_prefix("attn", prefix),
  )

@@ -432,44 +282,25 @@ class MiniCPM3DecoderLayer(nn.Module):
  rope_theta = getattr(config, "rope_theta", 10000)
  rope_scaling = getattr(config, "rope_scaling", None)
  max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
- if not global_server_args_dict["disable_mla"]:
- self.self_attn = MiniCPM3AttentionMLA(
- config=config,
- hidden_size=self.hidden_size,
- num_heads=config.num_attention_heads,
- qk_nope_head_dim=config.qk_nope_head_dim,
- qk_rope_head_dim=config.qk_rope_head_dim,
- v_head_dim=self.hidden_size // config.num_attention_heads,
- q_lora_rank=(
- config.q_lora_rank if hasattr(config, "q_lora_rank") else None
- ),
- kv_lora_rank=config.kv_lora_rank,
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
- max_position_embeddings=max_position_embeddings,
- quant_config=quant_config,
- layer_id=layer_id,
- prefix=add_prefix("self_attn", prefix),
- )
- else:
- self.self_attn = MiniCPM3Attention(
- config=config,
- hidden_size=self.hidden_size,
- num_heads=config.num_attention_heads,
- qk_nope_head_dim=config.qk_nope_head_dim,
- qk_rope_head_dim=config.qk_rope_head_dim,
- v_head_dim=self.hidden_size // config.num_attention_heads,
- q_lora_rank=(
- config.q_lora_rank if hasattr(config, "q_lora_rank") else None
- ),
- kv_lora_rank=config.kv_lora_rank,
- rope_theta=rope_theta,
- rope_scaling=rope_scaling,
- max_position_embeddings=max_position_embeddings,
- quant_config=quant_config,
- layer_id=layer_id,
- prefix=add_prefix("self_attn", prefix),
- )
+ self.self_attn = MiniCPM3AttentionMLA(
+ config=config,
+ hidden_size=self.hidden_size,
+ num_heads=config.num_attention_heads,
+ qk_nope_head_dim=config.qk_nope_head_dim,
+ qk_rope_head_dim=config.qk_rope_head_dim,
+ v_head_dim=self.hidden_size // config.num_attention_heads,
+ q_lora_rank=(
+ config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+ ),
+ kv_lora_rank=config.kv_lora_rank,
+ rope_theta=rope_theta,
+ rope_scaling=rope_scaling,
+ max_position_embeddings=max_position_embeddings,
+ quant_config=quant_config,
+ layer_id=layer_id,
+ prefix=add_prefix("self_attn", prefix),
+ )
+
  self.mlp = MiniCPM3MLP(
  hidden_size=self.hidden_size,
  intermediate_size=config.intermediate_size,
@@ -672,17 +503,16 @@ class MiniCPM3ForCausalLM(nn.Module):
  )
  weight_loader(param, loaded_weight)

- if not global_server_args_dict["disable_mla"]:
- for layer_id in range(self.config.num_hidden_layers):
- self_attn = self.model.layers[layer_id].self_attn
- w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
- 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
- ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
- self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
- self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
- if hasattr(self_attn.kv_b_proj, "weight_scale"):
- self_attn.w_scale = self_attn.kv_b_proj.weight_scale
- del self_attn.kv_b_proj
+ for layer_id in range(self.config.num_hidden_layers):
+ self_attn = self.model.layers[layer_id].self_attn
+ w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
+ 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+ ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+ self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+ self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+ if hasattr(self_attn.kv_b_proj, "weight_scale"):
+ self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+ del self_attn.kv_b_proj


  EntryClass = MiniCPM3ForCausalLM
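
The post-load loop above pre-splits each layer's kv_b_proj weight into the w_kc / w_vc blocks consumed by the MLA attention path. A toy-sized sketch of the same tensor reshuffle (the dimensions below are made up; the real ones come from the model config):

    import torch

    num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 8, 64, 64, 256
    kv_b_weight = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

    # Unflatten the output dim into (heads, nope + v) and split each head's block
    # into its key-compression and value-compression halves.
    w_kc, w_vc = kv_b_weight.unflatten(
        0, (-1, qk_nope_head_dim + v_head_dim)
    ).split([qk_nope_head_dim, v_head_dim], dim=1)

    # Same shape for w_kc, but memory laid out as [heads, kv_lora_rank, qk_nope_head_dim].
    w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
    w_vc = w_vc.contiguous().transpose(1, 2)  # [heads, kv_lora_rank, v_head_dim]

    print(w_kc.shape, w_vc.shape)  # torch.Size([8, 64, 256]) torch.Size([8, 256, 64])
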
sglang/srt/models/mixtral.py CHANGED
@@ -169,6 +169,7 @@ class MixtralAttention(nn.Module):
  self.scaling,
  num_kv_heads=self.num_kv_heads,
  layer_id=layer_id,
+ quant_config=quant_config,
  prefix=add_prefix("attn", prefix),
  )

sglang/srt/models/mixtral_quant.py CHANGED
@@ -232,6 +232,7 @@ class MixtralAttention(nn.Module):
  self.scaling,
  num_kv_heads=self.num_kv_heads,
  layer_id=layer_id,
+ quant_config=quant_config,
  prefix=add_prefix("attn", prefix),
  )