sglang 0.4.4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -26,15 +26,20 @@ from transformers import PretrainedConfig
 from vllm import _custom_ops as ops
 
 from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
     decode_attention_fwd_grouped_rope,
 )
+from sglang.srt.layers.dp_attention import (
+    dp_gather,
+    dp_scatter,
+    get_attention_dp_size,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -230,6 +235,7 @@ class DeepseekV2Attention(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
+        reduce_results: bool = True,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -241,10 +247,14 @@ class DeepseekV2Attention(nn.Module):
         self.v_head_dim = v_head_dim
         self.q_lora_rank = q_lora_rank
         self.kv_lora_rank = kv_lora_rank
+
+        self.dp_size = get_attention_dp_size()
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
         self.num_heads = num_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        assert num_heads % tp_size == 0
-        self.num_local_heads = num_heads // tp_size
+        assert num_heads % attn_tp_size == 0
+        self.num_local_heads = num_heads // attn_tp_size
         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
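
The head partitioning above now divides by the attention-TP group size instead of the global TP size. A minimal sketch of the arithmetic, assuming (this relation is inferred from how the diff uses the helpers, not taken from their source) that the attention-TP size is the global TP size divided by the attention-DP size; the concrete numbers are illustrative only:

    # Hedged sketch of the new head partitioning; attn_tp_size == tp_size // dp_size
    # is an assumption, and the numbers below are made up for illustration.
    tp_size = 8                                   # get_tensor_model_parallel_world_size()
    dp_size = 4                                   # get_attention_dp_size()
    attn_tp_size = tp_size // dp_size             # get_attention_tp_size() -> 2
    num_heads = 128
    assert num_heads % attn_tp_size == 0
    num_local_heads = num_heads // attn_tp_size   # 64 heads per attention rank, as in the diff
    print(num_local_heads)
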
@@ -272,6 +282,8 @@ class DeepseekV2Attention(nn.Module):
                 bias=False,
                 quant_config=quant_config,
                 prefix=add_prefix("q_proj", prefix),
+                tp_rank=attn_tp_rank,
+                tp_size=attn_tp_size,
             )
 
         self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -296,6 +308,9 @@ class DeepseekV2Attention(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("o_proj", prefix),
+            reduce_results=reduce_results,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
         )
         rope_scaling["rope_type"] = "deepseek_yarn"
         self.rotary_emb = get_rope_wrapper(
@@ -330,6 +345,12 @@ class DeepseekV2Attention(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
+        if hidden_states.shape[0] == 0:
+            assert (
+                not self.o_proj.reduce_results
+            ), "short-circuiting allreduce will lead to hangs"
+            return hidden_states
+
         if self.q_lora_rank is not None:
             q = self.q_a_proj(hidden_states)[0]
             q = self.q_a_layernorm(q)
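
The early return above is what the new reduce_results flag exists for: with data-parallel attention, a rank whose local batch is empty must not silently skip a tensor-parallel all-reduce that its peers are waiting on, so the reduction is moved out of o_proj and deferred to the decoder layer, where every rank participates. A toy, non-sglang sketch of the invariant being asserted (ToyRowParallelLinear is made up for illustration):

    import torch
    import torch.nn as nn

    class ToyRowParallelLinear(nn.Module):
        """Stand-in for a TP row-parallel linear with a reduce_results switch."""

        def __init__(self, dim: int, reduce_results: bool = True):
            super().__init__()
            self.proj = nn.Linear(dim, dim, bias=False)
            self.reduce_results = reduce_results

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            y = self.proj(x)
            if self.reduce_results:
                # A real implementation would block here on an all-reduce across
                # the TP group; every rank must reach this point or its peers hang.
                pass
            return y

    o_proj = ToyRowParallelLinear(16, reduce_results=False)
    hidden = torch.empty(0, 16)   # this rank received no tokens in the DP split
    if hidden.shape[0] == 0:
        # Mirrors the assertion in the diff: the early return is only safe
        # because this layer issues no collective of its own.
        assert not o_proj.reduce_results
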
@@ -385,8 +406,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
-        layer_id=None,
-        use_dp=False,
+        reduce_results: bool = True,
+        layer_id: int = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -398,96 +419,66 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.v_head_dim = v_head_dim
         self.q_lora_rank = q_lora_rank
         self.kv_lora_rank = kv_lora_rank
+        self.dp_size = get_attention_dp_size()
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
         self.num_heads = num_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        assert num_heads % tp_size == 0
-        self.num_local_heads = num_heads if use_dp else num_heads // tp_size
+        assert num_heads % attn_tp_size == 0
+        self.num_local_heads = num_heads // attn_tp_size
         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
 
-        if use_dp:
-            # For data parallel attention
-            if self.q_lora_rank is not None:
-                self.q_a_proj = ReplicatedLinear(
-                    self.hidden_size,
-                    self.q_lora_rank,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_a_proj", prefix),
-                )
-                self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
-                self.q_b_proj = ReplicatedLinear(
-                    q_lora_rank,
-                    self.num_heads * self.qk_head_dim,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_b_proj", prefix),
-                )
-            else:
-                self.q_proj = ReplicatedLinear(
-                    self.hidden_size,
-                    self.num_heads * self.qk_head_dim,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_proj", prefix),
-                )
-            self.kv_b_proj = ReplicatedLinear(
-                self.kv_lora_rank,
-                self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("kv_b_proj", prefix),
-            )
-            # O projection.
-            self.o_proj = ReplicatedLinear(
-                self.num_heads * self.v_head_dim,
+        # For tensor parallel attention
+        if self.q_lora_rank is not None:
+            self.q_a_proj = ReplicatedLinear(
                 self.hidden_size,
+                self.q_lora_rank,
                 bias=False,
                 quant_config=quant_config,
-                prefix=add_prefix("o_proj", prefix),
+                prefix=add_prefix("q_a_proj", prefix),
             )
-        else:
-            # For tensor parallel attention
-            if self.q_lora_rank is not None:
-                self.q_a_proj = ReplicatedLinear(
-                    self.hidden_size,
-                    self.q_lora_rank,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_a_proj", prefix),
-                )
-                self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
-                self.q_b_proj = ColumnParallelLinear(
-                    q_lora_rank,
-                    self.num_heads * self.qk_head_dim,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_b_proj", prefix),
-                )
-            else:
-                self.q_proj = ColumnParallelLinear(
-                    self.hidden_size,
-                    self.num_heads * self.qk_head_dim,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_proj", prefix),
-                )
-            self.kv_b_proj = ColumnParallelLinear(
-                self.kv_lora_rank,
-                self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+            self.q_b_proj = ColumnParallelLinear(
+                q_lora_rank,
+                self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
-                prefix=add_prefix("kv_b_proj", prefix),
+                prefix=add_prefix("q_b_proj", prefix),
+                tp_rank=attn_tp_rank,
+                tp_size=attn_tp_size,
             )
-            # O projection.
-            self.o_proj = RowParallelLinear(
-                self.num_heads * self.v_head_dim,
+        else:
+            self.q_proj = ColumnParallelLinear(
                 self.hidden_size,
+                self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
-                prefix=add_prefix("o_proj", prefix),
+                prefix=add_prefix("q_proj", prefix),
+                tp_rank=attn_tp_rank,
+                tp_size=attn_tp_size,
             )
+        self.kv_b_proj = ColumnParallelLinear(
+            self.kv_lora_rank,
+            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("kv_b_proj", prefix),
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+        )
+        # O projection.
+        self.o_proj = RowParallelLinear(
+            self.num_heads * self.v_head_dim,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            reduce_results=reduce_results,
+            prefix=add_prefix("o_proj", prefix),
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+        )
 
         self.kv_a_proj_with_mqa = ReplicatedLinear(
             self.hidden_size,
@@ -542,38 +533,49 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.w_vc = None
         self.w_scale = None
 
+        self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"]
+        self.flashinfer_mla_disable_ragged = global_server_args_dict[
+            "flashinfer_mla_disable_ragged"
+        ]
+        self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
+
+    def no_absorb(self, forward_batch: ForwardBatch) -> bool:
+        if self.enable_flashinfer_mla:
+            # Flashinfer MLA: Do not absorb when enabling ragged prefill
+            return (
+                not self.flashinfer_mla_disable_ragged
+                and forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+                and forward_batch.extend_prefix_lens.sum() == 0
+            )
+        else:
+            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
+            return (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+                and forward_batch.extend_prefix_lens.sum() == 0
+            )
+
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
+        if hidden_states.shape[0] == 0:
+            assert (
+                not self.o_proj.reduce_results
+            ), "short-circuiting allreduce will lead to hangs"
+            return hidden_states
 
-        def no_absorb() -> bool:
-            if global_server_args_dict["enable_flashinfer_mla"]:
-                # Flashinfer MLA: Do not absorb when enabling ragged prefill
-                return (
-                    not global_server_args_dict["flashinfer_mla_disable_ragged"]
-                    and forward_batch.forward_mode.is_extend()
-                    and not forward_batch.forward_mode.is_target_verify()
-                    and not forward_batch.forward_mode.is_draft_extend()
-                    and forward_batch.extend_prefix_lens.sum() == 0
-                )
-            else:
-                # Triton: Use normal computation for prefill and use weight absorption for extend/decode
-                return (
-                    forward_batch.forward_mode.is_extend()
-                    and not forward_batch.forward_mode.is_target_verify()
-                    and not forward_batch.forward_mode.is_draft_extend()
-                    and forward_batch.extend_prefix_lens.sum() == 0
-                )
-
-        if no_absorb():
+        if self.no_absorb(forward_batch):
             return self.forward_normal(positions, hidden_states, forward_batch)
         else:
             if _is_hip:
                 if (
-                    os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
+                    self.rocm_fused_decode_mla
                     and forward_batch.forward_mode.is_decode()
                 ):
                     return self.forward_absorb_fused_mla_rope(
@@ -845,34 +847,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         return output
 
 
-def all_gather(
-    input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
-):
-    all_lens = forward_batch.global_num_tokens_cpu
-    max_len = max(forward_batch.global_num_tokens_cpu)
-
-    if world_size == 1:
-        return input_tensor, 0, all_lens[0]
-
-    padded_tensor = torch.nn.functional.pad(
-        input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
-    )
-
-    group.all_gather_into_tensor(forward_batch.gathered_buffer, padded_tensor)
-
-    gathered_tensors = torch.concat(
-        [
-            forward_batch.gathered_buffer[i * max_len : i * max_len + all_lens[i]]
-            for i in range(world_size)
-        ]
-    )
-
-    start_index = 0 if rank == 0 else sum(all_lens[:rank])
-    end_index = start_index + all_lens[rank]
-
-    return gathered_tensors, start_index, end_index
-
-
 class DeepseekV2DecoderLayer(nn.Module):
 
     def __init__(
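
The module-level all_gather helper removed above (pad each rank's tokens to the batch-wide maximum, gather the padded chunks, then strip the padding using per-rank token counts) is the bookkeeping that dp_gather now provides centrally. A single-process simulation of that logic, with hypothetical names and random data purely for illustration:

    import torch

    def simulate_padded_gather(per_rank_tokens):
        """Mimic the deleted helper without any distributed process group."""
        all_lens = [t.shape[0] for t in per_rank_tokens]
        max_len = max(all_lens)
        # Pad every rank's slice to max_len rows, as torch.nn.functional.pad did.
        gathered_buffer = torch.cat(
            [
                torch.nn.functional.pad(t, (0, 0, 0, max_len - t.shape[0]))
                for t in per_rank_tokens
            ]
        )
        # Strip the padding back out, exactly like the removed list comprehension.
        gathered = torch.cat(
            [
                gathered_buffer[i * max_len : i * max_len + all_lens[i]]
                for i in range(len(all_lens))
            ]
        )
        return gathered, all_lens

    ranks = [torch.randn(3, 4), torch.randn(1, 4), torch.randn(2, 4)]
    gathered, all_lens = simulate_padded_gather(ranks)
    rank = 1
    start, end = sum(all_lens[:rank]), sum(all_lens[: rank + 1])
    assert torch.equal(gathered[start:end], ranks[rank])
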
@@ -888,14 +862,10 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        self.enable_dp_attention = (
-            not global_server_args_dict["disable_mla"]
-            and global_server_args_dict["enable_dp_attention"]
-        )
-        if self.enable_dp_attention:
-            self.tp_rank = get_tensor_model_parallel_rank()
-            self.tp_size = get_tensor_model_parallel_world_size()
-            self.tp_group = get_tp_group()
+        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
+        self.layer_id = layer_id
+        self.dp_size = get_attention_dp_size()
+
         if not global_server_args_dict["disable_mla"]:
             self.self_attn = DeepseekV2AttentionMLA(
                 config=config,
@@ -913,7 +883,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 max_position_embeddings=max_position_embeddings,
                 quant_config=quant_config,
                 layer_id=layer_id,
-                use_dp=self.enable_dp_attention,
+                reduce_results=False,
                 prefix=add_prefix("self_attn", prefix),
             )
         else:
@@ -933,8 +903,10 @@ class DeepseekV2DecoderLayer(nn.Module):
                 max_position_embeddings=max_position_embeddings,
                 quant_config=quant_config,
                 layer_id=layer_id,
+                reduce_results=False,
                 prefix=add_prefix("self_attn", prefix),
             )
+
         if is_nextn or (
             config.n_routed_experts is not None
             and layer_id >= config.first_k_dense_replace
@@ -965,33 +937,47 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        # Scatter
+        if self.dp_size != 1:
+            # important: forward batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
         # Self Attention
-        if not forward_batch.forward_mode.is_idle():
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        # Gather
+        if get_tensor_model_parallel_world_size() > 1:
+            # all gather and all reduce
+            if self.dp_size != 1:
+                hidden_states, local_hidden_states = (
+                    forward_batch.gathered_buffer,
+                    hidden_states,
+                )
+                dp_gather(
+                    hidden_states, local_hidden_states, forward_batch, self.layer_id
+                )
             else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
 
-            hidden_states = self.self_attn(
-                positions=positions,
-                hidden_states=hidden_states,
-                forward_batch=forward_batch,
-            )
-            hidden_states, residual = self.post_attention_layernorm(
-                hidden_states, residual
-            )
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
 
         # Fully Connected
-        if self.enable_dp_attention:
-            hidden_states, start_idx, end_idx = all_gather(
-                hidden_states, forward_batch, self.tp_rank, self.tp_size, self.tp_group
-            )
-            hidden_states = self.mlp(hidden_states)
-            hidden_states = hidden_states[start_idx:end_idx]
-        else:
-            hidden_states = self.mlp(hidden_states)
-
+        hidden_states = self.mlp(hidden_states)
         return hidden_states, residual
 
 
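
The rewritten forward above replaces the old gather-for-MLP path with a scatter/gather pair around attention: the layer norms and MLP operate on the batch-wide token tensor, while self-attention sees only this rank's slice and defers its all-reduce to the gather step. A rough single-process model of that flow (dp_scatter_sim, dp_gather_sim and the tanh/relu stand-ins are illustrative placeholders, not sglang APIs):

    import torch

    global_num_tokens = [3, 1, 2]   # tokens contributed by each attention-DP rank
    hidden_size = 8
    gathered = torch.randn(sum(global_num_tokens), hidden_size)   # "global" hidden states

    def dp_scatter_sim(global_hidden, rank):
        # Take this rank's contiguous slice of the gathered buffer.
        start = sum(global_num_tokens[:rank])
        return global_hidden[start : start + global_num_tokens[rank]]

    def dp_gather_sim(local_slices):
        # Reassemble every rank's slice back into the global buffer.
        return torch.cat(local_slices, dim=0)

    # Per-layer flow mirroring the rewritten DeepseekV2DecoderLayer.forward:
    local = [dp_scatter_sim(gathered, r) for r in range(len(global_num_tokens))]  # Scatter
    local = [torch.tanh(x) for x in local]      # stand-in for self_attn (no reduce inside)
    gathered = dp_gather_sim(local)             # Gather
    gathered = torch.relu(gathered)             # stand-in for the MLP over all tokens
    assert gathered.shape[0] == sum(global_num_tokens)
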
@@ -1027,12 +1013,27 @@ class DeepseekV2Model(nn.Module):
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
+        self.dp_size = get_attention_dp_size()
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
+
+        # Gather
+        if self.dp_size != 1:
+            input_ids, local_input_ids = (
+                torch.empty(
+                    (forward_batch.gathered_buffer.shape[0],),
+                    dtype=input_ids.dtype,
+                    device=input_ids.device,
+                ),
+                input_ids,
+            )
+            dp_gather(input_ids, local_input_ids, forward_batch, "embedding")
+
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
@@ -1059,22 +1060,14 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.model = DeepseekV2Model(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
-        if global_server_args_dict["enable_dp_attention"]:
-            self.lm_head = ReplicatedLinear(
-                config.hidden_size,
-                config.vocab_size,
-                bias=False,
-                prefix=add_prefix("lm_head", prefix),
-            )
-            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
-        else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("lm_head", prefix),
-            )
-            self.logits_processor = LogitsProcessor(config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+        )
+        self.logits_processor = LogitsProcessor(config)
+        self.dp_size = get_attention_dp_size()
 
     @torch.no_grad()
     def forward(
@@ -1084,6 +1077,16 @@ class DeepseekV2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
+
+        if self.dp_size != 1:
+            # important: forward batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head, forward_batch
         )