sglang-0.4.3.post4-py3-none-any.whl → sglang-0.4.4.post1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. sglang/bench_serving.py +1 -1
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/function_call_parser.py +33 -2
  14. sglang/srt/hf_transformers_utils.py +16 -1
  15. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  17. sglang/srt/layers/attention/triton_backend.py +1 -3
  18. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  19. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  20. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  21. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  22. sglang/srt/layers/attention/vision.py +43 -62
  23. sglang/srt/layers/dp_attention.py +30 -2
  24. sglang/srt/layers/elementwise.py +411 -0
  25. sglang/srt/layers/linear.py +1 -1
  26. sglang/srt/layers/logits_processor.py +1 -0
  27. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  28. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  37. sglang/srt/layers/moe/router.py +342 -0
  38. sglang/srt/layers/parameter.py +10 -0
  39. sglang/srt/layers/quantization/__init__.py +90 -68
  40. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  41. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/fp8.py +174 -106
  68. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  69. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  70. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  71. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  72. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  73. sglang/srt/layers/rotary_embedding.py +5 -3
  74. sglang/srt/layers/sampler.py +29 -35
  75. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  76. sglang/srt/lora/backend/__init__.py +9 -12
  77. sglang/srt/managers/cache_controller.py +74 -8
  78. sglang/srt/managers/data_parallel_controller.py +1 -1
  79. sglang/srt/managers/image_processor.py +37 -631
  80. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  81. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  82. sglang/srt/managers/image_processors/llava.py +152 -0
  83. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  84. sglang/srt/managers/image_processors/mlama.py +60 -0
  85. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  86. sglang/srt/managers/io_struct.py +32 -15
  87. sglang/srt/managers/multi_modality_padding.py +134 -0
  88. sglang/srt/managers/schedule_batch.py +213 -118
  89. sglang/srt/managers/schedule_policy.py +40 -8
  90. sglang/srt/managers/scheduler.py +176 -683
  91. sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
  92. sglang/srt/managers/tokenizer_manager.py +6 -6
  93. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  94. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  95. sglang/srt/mem_cache/chunk_cache.py +12 -44
  96. sglang/srt/mem_cache/hiradix_cache.py +71 -34
  97. sglang/srt/mem_cache/memory_pool.py +81 -17
  98. sglang/srt/mem_cache/paged_allocator.py +283 -0
  99. sglang/srt/mem_cache/radix_cache.py +117 -36
  100. sglang/srt/model_executor/cuda_graph_runner.py +68 -20
  101. sglang/srt/model_executor/forward_batch_info.py +23 -10
  102. sglang/srt/model_executor/model_runner.py +63 -63
  103. sglang/srt/model_loader/loader.py +2 -1
  104. sglang/srt/model_loader/weight_utils.py +1 -1
  105. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  106. sglang/srt/models/deepseek_nextn.py +23 -3
  107. sglang/srt/models/deepseek_v2.py +200 -191
  108. sglang/srt/models/grok.py +374 -119
  109. sglang/srt/models/minicpmv.py +28 -89
  110. sglang/srt/models/mllama.py +1 -1
  111. sglang/srt/models/qwen2.py +0 -1
  112. sglang/srt/models/qwen2_5_vl.py +25 -50
  113. sglang/srt/models/qwen2_vl.py +33 -49
  114. sglang/srt/openai_api/adapter.py +59 -35
  115. sglang/srt/openai_api/protocol.py +8 -1
  116. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  117. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  118. sglang/srt/server_args.py +24 -16
  119. sglang/srt/speculative/eagle_worker.py +75 -39
  120. sglang/srt/utils.py +104 -9
  121. sglang/test/runners.py +104 -10
  122. sglang/test/test_block_fp8.py +106 -16
  123. sglang/test/test_custom_ops.py +88 -0
  124. sglang/test/test_utils.py +20 -4
  125. sglang/utils.py +0 -4
  126. sglang/version.py +1 -1
  127. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
  128. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
  129. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
  130. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
  131. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_nextn.py
@@ -30,6 +30,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
     normalize_e4m3fn_to_e4m3fnuz,
 )
+from sglang.srt.layers.quantization.int8_utils import (
+    block_dequant as int8_block_dequant,
+)
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -40,7 +43,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
 from sglang.srt.utils import add_prefix, is_hip
 
-is_hip_ = is_hip()
+_is_hip = is_hip()
 
 
 class DeepseekModelNextN(nn.Module):
@@ -277,7 +280,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
                 weight_block_size = self.quant_config.weight_block_size
                 if weight_block_size is not None:
                     assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                    if is_hip_:
+                    if _is_hip:
                         weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                             weight=w,
                             weight_scale=self_attn.kv_b_proj.weight_scale_inv,
@@ -291,6 +294,23 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
                         weight, weight_scale, weight_block_size
                     )
                     self_attn.w_scale = scale
+            if w.dtype == torch.int8:
+                if hasattr(self.quant_config, "weight_block_size"):
+                    # block-wise int8 need it
+                    weight_block_size = self.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        weight = w
+                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
+                        w = int8_block_dequant(
+                            weight, weight_scale, weight_block_size
+                        ).to(torch.bfloat16)
+                else:
+                    # channel-wise int8 need it
+                    assert hasattr(self_attn.kv_b_proj, "weight_scale")
+                    w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
+                        torch.bfloat16
+                    )
             w_kc, w_vc = w.unflatten(
                 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
             ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
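Note on the new call: int8_block_dequant is imported from sglang/srt/layers/quantization/int8_utils.py, a new module that this diff does not expand. As a rough illustration only (not the library's implementation, which may be a fused kernel), block-wise dequantization broadcasts one scale per (block_n x block_k) tile of the int8 weight:

import torch

def block_dequant_sketch(w_q: torch.Tensor, w_s: torch.Tensor, block_size) -> torch.Tensor:
    # w_q: int8 weight of shape (N, K)
    # w_s: per-block scales of shape (ceil(N / block_n), ceil(K / block_k))
    block_n, block_k = block_size  # e.g. [128, 128], matching weight_block_size above
    n, k = w_q.shape
    # Expand each block scale across its tile, crop to the weight shape, then multiply.
    scales = w_s.repeat_interleave(block_n, dim=0)[:n]
    scales = scales.repeat_interleave(block_k, dim=1)[:, :k]
    return w_q.to(torch.float32) * scales

# Mirrors the usage in the hunk above:
#   w = int8_block_dequant(weight, weight_scale, weight_block_size).to(torch.bfloat16)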
@@ -301,7 +321,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
                 and self_attn.w_scale is None
             ):
                 self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                if is_hip_:
+                if _is_hip:
                     self_attn.w_scale *= 2.0
 
 
sglang/srt/models/deepseek_v2.py
@@ -26,15 +26,20 @@ from transformers import PretrainedConfig
 from vllm import _custom_ops as ops
 
 from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
     decode_attention_fwd_grouped_rope,
 )
+from sglang.srt.layers.dp_attention import (
+    dp_gather,
+    dp_scatter,
+    get_attention_dp_size,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -65,7 +70,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, is_cuda_available, is_hip
 
-is_hip_ = is_hip()
+_is_hip = is_hip()
 
 if is_cuda_available():
     from sgl_kernel import bmm_fp8
@@ -230,6 +235,7 @@ class DeepseekV2Attention(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
+        reduce_results: bool = True,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -241,10 +247,14 @@ class DeepseekV2Attention(nn.Module):
         self.v_head_dim = v_head_dim
         self.q_lora_rank = q_lora_rank
         self.kv_lora_rank = kv_lora_rank
+
+        self.dp_size = get_attention_dp_size()
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
         self.num_heads = num_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        assert num_heads % tp_size == 0
-        self.num_local_heads = num_heads // tp_size
+        assert num_heads % attn_tp_size == 0
+        self.num_local_heads = num_heads // attn_tp_size
         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
@@ -272,6 +282,8 @@ class DeepseekV2Attention(nn.Module):
                 bias=False,
                 quant_config=quant_config,
                 prefix=add_prefix("q_proj", prefix),
+                tp_rank=attn_tp_rank,
+                tp_size=attn_tp_size,
             )
 
         self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -296,6 +308,9 @@ class DeepseekV2Attention(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("o_proj", prefix),
+            reduce_results=reduce_results,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
         )
         rope_scaling["rope_type"] = "deepseek_yarn"
         self.rotary_emb = get_rope_wrapper(
@@ -330,6 +345,12 @@ class DeepseekV2Attention(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
+        if hidden_states.shape[0] == 0:
+            assert (
+                not self.o_proj.reduce_results
+            ), "short-circuiting allreduce will lead to hangs"
+            return hidden_states
+
         if self.q_lora_rank is not None:
             q = self.q_a_proj(hidden_states)[0]
             q = self.q_a_layernorm(q)
@@ -385,8 +406,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
-        layer_id=None,
-        use_dp=False,
+        reduce_results: bool = True,
+        layer_id: int = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -398,96 +419,66 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.v_head_dim = v_head_dim
         self.q_lora_rank = q_lora_rank
         self.kv_lora_rank = kv_lora_rank
+        self.dp_size = get_attention_dp_size()
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
         self.num_heads = num_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        assert num_heads % tp_size == 0
-        self.num_local_heads = num_heads if use_dp else num_heads // tp_size
+        assert num_heads % attn_tp_size == 0
+        self.num_local_heads = num_heads // attn_tp_size
         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
 
-        if use_dp:
-            # For data parallel attention
-            if self.q_lora_rank is not None:
-                self.q_a_proj = ReplicatedLinear(
-                    self.hidden_size,
-                    self.q_lora_rank,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_a_proj", prefix),
-                )
-                self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
-                self.q_b_proj = ReplicatedLinear(
-                    q_lora_rank,
-                    self.num_heads * self.qk_head_dim,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_b_proj", prefix),
-                )
-            else:
-                self.q_proj = ReplicatedLinear(
-                    self.hidden_size,
-                    self.num_heads * self.qk_head_dim,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_proj", prefix),
-                )
-            self.kv_b_proj = ReplicatedLinear(
-                self.kv_lora_rank,
-                self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("kv_b_proj", prefix),
-            )
-            # O projection.
-            self.o_proj = ReplicatedLinear(
-                self.num_heads * self.v_head_dim,
+        # For tensor parallel attention
+        if self.q_lora_rank is not None:
+            self.q_a_proj = ReplicatedLinear(
                 self.hidden_size,
+                self.q_lora_rank,
                 bias=False,
                 quant_config=quant_config,
-                prefix=add_prefix("o_proj", prefix),
+                prefix=add_prefix("q_a_proj", prefix),
             )
-        else:
-            # For tensor parallel attention
-            if self.q_lora_rank is not None:
-                self.q_a_proj = ReplicatedLinear(
-                    self.hidden_size,
-                    self.q_lora_rank,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_a_proj", prefix),
-                )
-                self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
-                self.q_b_proj = ColumnParallelLinear(
-                    q_lora_rank,
-                    self.num_heads * self.qk_head_dim,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_b_proj", prefix),
-                )
-            else:
-                self.q_proj = ColumnParallelLinear(
-                    self.hidden_size,
-                    self.num_heads * self.qk_head_dim,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_proj", prefix),
-                )
-            self.kv_b_proj = ColumnParallelLinear(
-                self.kv_lora_rank,
-                self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+            self.q_b_proj = ColumnParallelLinear(
+                q_lora_rank,
+                self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
-                prefix=add_prefix("kv_b_proj", prefix),
+                prefix=add_prefix("q_b_proj", prefix),
+                tp_rank=attn_tp_rank,
+                tp_size=attn_tp_size,
             )
-            # O projection.
-            self.o_proj = RowParallelLinear(
-                self.num_heads * self.v_head_dim,
+        else:
+            self.q_proj = ColumnParallelLinear(
                 self.hidden_size,
+                self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
-                prefix=add_prefix("o_proj", prefix),
+                prefix=add_prefix("q_proj", prefix),
+                tp_rank=attn_tp_rank,
+                tp_size=attn_tp_size,
             )
+        self.kv_b_proj = ColumnParallelLinear(
+            self.kv_lora_rank,
+            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("kv_b_proj", prefix),
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+        )
+        # O projection.
+        self.o_proj = RowParallelLinear(
+            self.num_heads * self.v_head_dim,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            reduce_results=reduce_results,
+            prefix=add_prefix("o_proj", prefix),
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+        )
 
         self.kv_a_proj_with_mqa = ReplicatedLinear(
             self.hidden_size,
@@ -542,36 +533,49 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.w_vc = None
         self.w_scale = None
 
+        self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"]
+        self.flashinfer_mla_disable_ragged = global_server_args_dict[
+            "flashinfer_mla_disable_ragged"
+        ]
+        self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
+
+    def no_absorb(self, forward_batch: ForwardBatch) -> bool:
+        if self.enable_flashinfer_mla:
+            # Flashinfer MLA: Do not absorb when enabling ragged prefill
+            return (
+                not self.flashinfer_mla_disable_ragged
+                and forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+                and forward_batch.extend_prefix_lens.sum() == 0
+            )
+        else:
+            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
+            return (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+                and forward_batch.extend_prefix_lens.sum() == 0
+            )
+
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
+        if hidden_states.shape[0] == 0:
+            assert (
+                not self.o_proj.reduce_results
+            ), "short-circuiting allreduce will lead to hangs"
+            return hidden_states
 
-        def no_absorb() -> bool:
-            if global_server_args_dict["enable_flashinfer_mla"]:
-                # Flashinfer MLA: Do not absorb when enabling ragged prefill
-                return (
-                    not global_server_args_dict["flashinfer_mla_disable_ragged"]
-                    and forward_batch.forward_mode.is_extend()
-                    and forward_batch.extend_prefix_lens.sum() == 0
-                )
-            else:
-                # Triton: Use normal computation for prefill and use weight absorption for extend/decode
-                return (
-                    forward_batch.forward_mode.is_extend()
-                    and not forward_batch.forward_mode.is_target_verify()
-                    and not forward_batch.forward_mode.is_draft_extend()
-                    and forward_batch.extend_prefix_lens.sum() == 0
-                )
-
-        if no_absorb():
+        if self.no_absorb(forward_batch):
             return self.forward_normal(positions, hidden_states, forward_batch)
         else:
-            if is_hip_:
+            if _is_hip:
                 if (
-                    os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
+                    self.rocm_fused_decode_mla
                     and forward_batch.forward_mode.is_decode()
                 ):
                     return self.forward_absorb_fused_mla_rope(
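The DeepseekV2AttentionMLA hunks above replace the old use_dp branching with sharding derived from get_attention_tp_rank / get_attention_tp_size / get_attention_dp_size. Those helpers come from sglang/srt/layers/dp_attention.py (+30 -2 in this diff), which is not expanded here; the sketch below is only an assumption about their arithmetic, namely that the attention-DP replicas partition the model tensor-parallel group:

# Sketch only; the real helpers live in sglang.srt.layers.dp_attention.
def attention_tp_layout(tp_rank: int, tp_size: int, dp_size: int):
    assert tp_size % dp_size == 0, "attention DP must evenly divide model TP"
    attn_tp_size = tp_size // dp_size       # ranks that shard the attention heads
    attn_tp_rank = tp_rank % attn_tp_size   # this rank's position inside that shard group
    attn_dp_rank = tp_rank // attn_tp_size  # which attention-DP replica this rank serves
    return attn_tp_rank, attn_tp_size, attn_dp_rank

# Under that assumption, tp_size=16 with dp_size=2 gives attn_tp_size=8, so
# num_local_heads = num_heads // 8 rather than num_heads // 16.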
@@ -843,34 +847,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         return output
 
 
-def all_gather(
-    input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
-):
-    if world_size == 1:
-        return input_tensor
-
-    all_lens = forward_batch.global_num_tokens_cpu
-    max_len = max(forward_batch.global_num_tokens_cpu)
-
-    padded_tensor = torch.nn.functional.pad(
-        input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
-    )
-
-    group.all_gather_into_tensor(forward_batch.gathered_buffer, padded_tensor)
-
-    gathered_tensors = torch.concat(
-        [
-            forward_batch.gathered_buffer[i * max_len : i * max_len + all_lens[i]]
-            for i in range(world_size)
-        ]
-    )
-
-    start_index = 0 if rank == 0 else sum(all_lens[:rank])
-    end_index = start_index + all_lens[rank]
-
-    return gathered_tensors, start_index, end_index
-
-
 class DeepseekV2DecoderLayer(nn.Module):
 
     def __init__(
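The module-level all_gather removed above is what the new dp_gather / dp_scatter calls replace. Restated as a self-contained sketch (plain torch.distributed instead of the pre-allocated forward_batch.gathered_buffer), the bookkeeping it performed was: pad each rank's tokens up to the per-rank maximum, all-gather, drop the padding, and remember which slice belongs to this rank:

import torch
import torch.distributed as dist

def padded_all_gather_sketch(local: torch.Tensor, all_lens: list, rank: int):
    # local: (num_local_tokens, hidden) on this rank; all_lens: token counts per rank.
    world_size = len(all_lens)
    if world_size == 1:
        return local, 0, local.shape[0]

    max_len = max(all_lens)
    # Pad the token dimension so every rank contributes the same number of rows.
    padded = torch.nn.functional.pad(local, (0, 0, 0, max_len - local.shape[0]))

    buffer = torch.empty(
        (world_size * max_len, local.shape[1]), dtype=local.dtype, device=local.device
    )
    dist.all_gather_into_tensor(buffer, padded)

    # Keep only the valid rows of every rank, in rank order.
    gathered = torch.cat(
        [buffer[i * max_len : i * max_len + all_lens[i]] for i in range(world_size)]
    )
    start = sum(all_lens[:rank])
    end = start + all_lens[rank]
    # gathered[start:end] is this rank's original, unpadded slice.
    return gathered, start, end

The decoder-layer hunk below performs the same gather through dp_gather with the reusable forward_batch.gathered_buffer instead of allocating per call.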
@@ -886,14 +862,10 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        self.enable_dp_attention = (
-            not global_server_args_dict["disable_mla"]
-            and global_server_args_dict["enable_dp_attention"]
-        )
-        if self.enable_dp_attention:
-            self.tp_rank = get_tensor_model_parallel_rank()
-            self.tp_size = get_tensor_model_parallel_world_size()
-            self.tp_group = get_tp_group()
+        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
+        self.layer_id = layer_id
+        self.dp_size = get_attention_dp_size()
+
         if not global_server_args_dict["disable_mla"]:
             self.self_attn = DeepseekV2AttentionMLA(
                 config=config,
@@ -911,7 +883,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 max_position_embeddings=max_position_embeddings,
                 quant_config=quant_config,
                 layer_id=layer_id,
-                use_dp=self.enable_dp_attention,
+                reduce_results=False,
                 prefix=add_prefix("self_attn", prefix),
             )
         else:
@@ -931,8 +903,10 @@ class DeepseekV2DecoderLayer(nn.Module):
                 max_position_embeddings=max_position_embeddings,
                 quant_config=quant_config,
                 layer_id=layer_id,
+                reduce_results=False,
                 prefix=add_prefix("self_attn", prefix),
             )
+
         if is_nextn or (
             config.n_routed_experts is not None
             and layer_id >= config.first_k_dense_replace
@@ -963,33 +937,47 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        # Scatter
+        if self.dp_size != 1:
+            # important: forward batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
         # Self Attention
-        if not forward_batch.forward_mode.is_idle():
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        # Gather
+        if get_tensor_model_parallel_world_size() > 1:
+            # all gather and all reduce
+            if self.dp_size != 1:
+                hidden_states, local_hidden_states = (
+                    forward_batch.gathered_buffer,
+                    hidden_states,
+                )
+                dp_gather(
+                    hidden_states, local_hidden_states, forward_batch, self.layer_id
+                )
             else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
 
-            hidden_states = self.self_attn(
-                positions=positions,
-                hidden_states=hidden_states,
-                forward_batch=forward_batch,
-            )
-            hidden_states, residual = self.post_attention_layernorm(
-                hidden_states, residual
-            )
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
 
         # Fully Connected
-        if self.enable_dp_attention:
-            hidden_states, start_idx, end_idx = all_gather(
-                hidden_states, forward_batch, self.tp_rank, self.tp_size, self.tp_group
-            )
-            hidden_states = self.mlp(hidden_states)
-            hidden_states = hidden_states[start_idx:end_idx]
-        else:
-            hidden_states = self.mlp(hidden_states)
-
+        hidden_states = self.mlp(hidden_states)
         return hidden_states, residual
 
 
@@ -1025,12 +1013,27 @@ class DeepseekV2Model(nn.Module):
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
+        self.dp_size = get_attention_dp_size()
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
+
+        # Gather
+        if self.dp_size != 1:
+            input_ids, local_input_ids = (
+                torch.empty(
+                    (forward_batch.gathered_buffer.shape[0],),
+                    dtype=input_ids.dtype,
+                    device=input_ids.device,
+                ),
+                input_ids,
+            )
+            dp_gather(input_ids, local_input_ids, forward_batch, "embedding")
+
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
@@ -1057,22 +1060,14 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.model = DeepseekV2Model(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
-        if global_server_args_dict["enable_dp_attention"]:
-            self.lm_head = ReplicatedLinear(
-                config.hidden_size,
-                config.vocab_size,
-                bias=False,
-                prefix=add_prefix("lm_head", prefix),
-            )
-            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
-        else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("lm_head", prefix),
-            )
-            self.logits_processor = LogitsProcessor(config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+        )
+        self.logits_processor = LogitsProcessor(config)
+        self.dp_size = get_attention_dp_size()
 
     @torch.no_grad()
     def forward(
@@ -1082,6 +1077,16 @@ class DeepseekV2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
+
+        if self.dp_size != 1:
+            # important: forward batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head, forward_batch
         )
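Before the logits processor, the gathered hidden states are scattered back so that each attention-DP rank computes logits only for its own tokens. A minimal sketch of the scatter side, under the assumption that dp_scatter copies this DP rank's contiguous token range out of the globally gathered tensor (again, the real implementation lives in sglang/srt/layers/dp_attention.py and is not shown in this diff):

import torch

# Sketch only; assumed semantics of dp_scatter.
def dp_scatter_sketch(
    local_out: torch.Tensor,      # e.g. a view such as gathered_buffer[:num_local_tokens]
    global_hidden: torch.Tensor,  # hidden states for all DP ranks' tokens, concatenated
    all_lens: list,               # tokens per DP rank, like global_num_tokens_cpu above
    dp_rank: int,
) -> None:
    start = sum(all_lens[:dp_rank])
    end = start + all_lens[dp_rank]
    local_out.copy_(global_hidden[start:end])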
@@ -1188,7 +1193,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 weight_block_size = self.quant_config.weight_block_size
                 if weight_block_size is not None:
                     assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                    if is_hip_:
+                    if _is_hip:
                         weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                             weight=w,
                             weight_scale=self_attn.kv_b_proj.weight_scale_inv,
@@ -1202,18 +1207,22 @@ class DeepseekV2ForCausalLM(nn.Module):
                         weight, weight_scale, weight_block_size
                     )
                     self_attn.w_scale = scale
-            if (
-                hasattr(self.quant_config, "weight_block_size")
-                and w.dtype == torch.int8
-            ):
-                weight_block_size = self.quant_config.weight_block_size
-                if weight_block_size is not None:
-                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                    weight = w
-                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
-                    w = int8_block_dequant(
-                        weight, weight_scale, weight_block_size
-                    ).to(torch.bfloat16)
+            if w.dtype == torch.int8:
+                if hasattr(self.quant_config, "weight_block_size"):
+                    # block-wise int8 need it
+                    weight_block_size = self.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        weight = w
+                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
+                        w = int8_block_dequant(
+                            weight, weight_scale, weight_block_size
+                        ).to(torch.bfloat16)
+                else:
+                    # channel-wise int8 need it
+                    w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
+                        torch.bfloat16
+                    )
             w_kc, w_vc = w.unflatten(
                 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
             ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
@@ -1224,7 +1233,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 and self_attn.w_scale is None
             ):
                 self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                if is_hip_:
+                if _is_hip:
                     self_attn.w_scale *= 2.0
 
     def get_embed_and_head(self):