sglang 0.4.4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, exactly as they appear in their public registry. It is provided for informational purposes only.
- sglang/srt/function_call_parser.py +33 -2
- sglang/srt/layers/dp_attention.py +30 -2
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/logits_processor.py +1 -0
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/managers/cache_controller.py +2 -0
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/schedule_batch.py +1 -1
- sglang/srt/managers/scheduler.py +52 -18
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +9 -1
- sglang/srt/mem_cache/memory_pool.py +4 -1
- sglang/srt/model_executor/cuda_graph_runner.py +59 -16
- sglang/srt/model_executor/forward_batch_info.py +13 -4
- sglang/srt/models/deepseek_v2.py +180 -177
- sglang/srt/models/grok.py +374 -119
- sglang/srt/openai_api/adapter.py +22 -20
- sglang/srt/server_args.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +24 -22
- {sglang-0.4.4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
```diff
@@ -26,15 +26,20 @@ from transformers import PretrainedConfig
 from vllm import _custom_ops as ops
 
 from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
     decode_attention_fwd_grouped_rope,
 )
+from sglang.srt.layers.dp_attention import (
+    dp_gather,
+    dp_scatter,
+    get_attention_dp_size,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
```
```diff
@@ -230,6 +235,7 @@ class DeepseekV2Attention(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
+        reduce_results: bool = True,
         prefix: str = "",
     ) -> None:
         super().__init__()
```
```diff
@@ -241,10 +247,14 @@ class DeepseekV2Attention(nn.Module):
         self.v_head_dim = v_head_dim
         self.q_lora_rank = q_lora_rank
         self.kv_lora_rank = kv_lora_rank
+
+        self.dp_size = get_attention_dp_size()
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
         self.num_heads = num_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        assert num_heads % tp_size == 0
-        self.num_local_heads = num_heads // tp_size
+        assert num_heads % attn_tp_size == 0
+        self.num_local_heads = num_heads // attn_tp_size
         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
```
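Both attention classes now size their per-rank head count against the attention tensor-parallel group rather than the full TP world. A minimal standalone sketch of that arithmetic (plain Python with illustrative head counts; not the sglang helpers):

```python
def local_heads(num_heads: int, attn_tp_size: int) -> int:
    # Mirrors the new assert + division: every attention-TP rank must own
    # a whole number of query heads.
    assert num_heads % attn_tp_size == 0
    return num_heads // attn_tp_size


# e.g. 128 query heads across an attention-TP group of 4 ranks -> 32 heads each
print(local_heads(128, 4))   # 32
print(local_heads(128, 16))  # 8
```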
```diff
@@ -272,6 +282,8 @@ class DeepseekV2Attention(nn.Module):
                 bias=False,
                 quant_config=quant_config,
                 prefix=add_prefix("q_proj", prefix),
+                tp_rank=attn_tp_rank,
+                tp_size=attn_tp_size,
             )
 
         self.kv_a_proj_with_mqa = ReplicatedLinear(
```
```diff
@@ -296,6 +308,9 @@ class DeepseekV2Attention(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("o_proj", prefix),
+            reduce_results=reduce_results,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
         )
         rope_scaling["rope_type"] = "deepseek_yarn"
         self.rotary_emb = get_rope_wrapper(
```
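The `reduce_results` flag forwarded to `o_proj` decides whether the row-parallel projection all-reduces its own output. The sketch below is a single-process torch illustration (simulated ranks, not the sglang/vLLM linear classes) of why a row-parallel matmul yields partial sums that must be reduced exactly once, somewhere:

```python
import torch

torch.manual_seed(0)
tp_size = 4
x = torch.randn(3, 8)   # [tokens, in_features]
w = torch.randn(8, 5)   # [in_features, out_features]

# Row-parallel layout: each simulated rank owns a slice of the input dimension.
x_shards = x.chunk(tp_size, dim=1)
w_shards = w.chunk(tp_size, dim=0)

# With reduce_results=False every rank would return only its partial product...
partials = [xs @ ws for xs, ws in zip(x_shards, w_shards)]

# ...and the sum over ranks (the all-reduce) has to happen later, exactly once.
reduced = sum(partials)
assert torch.allclose(reduced, x @ w, atol=1e-5)
print("partial sums reduce to the full matmul:", reduced.shape)
```

Deferring that reduction appears to be why the decoder layer below constructs the attention modules with `reduce_results=False` and performs the all-reduce (or DP gather) itself after `self_attn` returns.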
```diff
@@ -330,6 +345,12 @@ class DeepseekV2Attention(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
+        if hidden_states.shape[0] == 0:
+            assert (
+                not self.o_proj.reduce_results
+            ), "short-circuiting allreduce will lead to hangs"
+            return hidden_states
+
         if self.q_lora_rank is not None:
             q = self.q_a_proj(hidden_states)[0]
             q = self.q_a_layernorm(q)
```
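The new early return for zero-token batches is only safe because `o_proj` no longer reduces internally: a rank that skips a collective while its peers enter it deadlocks the group. The toy script below (gloo backend on CPU, two spawned processes; illustrative only, not sglang code) shows the safe shape of that pattern, where per-rank compute may be skipped but every rank still issues the same collective:

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29501", rank=rank, world_size=world_size
    )
    # Pretend rank 1 received an empty micro-batch (0 local tokens).
    num_tokens = 0 if rank == 1 else 3
    hidden = torch.ones(num_tokens, 4) * (rank + 1)

    # Per-rank compute can be short-circuited for empty input...
    if num_tokens > 0:
        hidden = hidden * 2  # stand-in for attention + o_proj without reduction

    # ...but the collective runs unconditionally, on a tensor whose shape is
    # identical on every rank, so no rank is left waiting on the others.
    token_count = torch.tensor([num_tokens])
    dist.all_reduce(token_count)
    print(f"rank {rank}: tokens across ranks = {int(token_count)}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```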
```diff
@@ -385,8 +406,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
-        layer_id=None,
-        use_dp=False,
+        reduce_results: bool = True,
+        layer_id: int = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
```
```diff
@@ -398,96 +419,66 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.v_head_dim = v_head_dim
         self.q_lora_rank = q_lora_rank
         self.kv_lora_rank = kv_lora_rank
+        self.dp_size = get_attention_dp_size()
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
         self.num_heads = num_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        assert num_heads % tp_size == 0
-        self.num_local_heads = num_heads if use_dp else num_heads // tp_size
+        assert num_heads % attn_tp_size == 0
+        self.num_local_heads = num_heads // attn_tp_size
         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
 
-        if use_dp:
-            # For data parallel attention
-            if self.q_lora_rank is not None:
-                self.q_a_proj = ReplicatedLinear(
-                    self.hidden_size,
-                    self.q_lora_rank,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_a_proj", prefix),
-                )
-                self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
-                self.q_b_proj = ReplicatedLinear(
-                    q_lora_rank,
-                    self.num_heads * self.qk_head_dim,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_b_proj", prefix),
-                )
-            else:
-                self.q_proj = ReplicatedLinear(
-                    self.hidden_size,
-                    self.num_heads * self.qk_head_dim,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_proj", prefix),
-                )
-            self.kv_b_proj = ReplicatedLinear(
-                self.kv_lora_rank,
-                self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("kv_b_proj", prefix),
-            )
-            # O projection.
-            self.o_proj = ReplicatedLinear(
-                self.num_heads * self.v_head_dim,
+        # For tensor parallel attention
+        if self.q_lora_rank is not None:
+            self.q_a_proj = ReplicatedLinear(
                 self.hidden_size,
+                self.q_lora_rank,
                 bias=False,
                 quant_config=quant_config,
-                prefix=add_prefix("o_proj", prefix),
+                prefix=add_prefix("q_a_proj", prefix),
             )
-        else:
-            # For tensor parallel attention
-            if self.q_lora_rank is not None:
-                self.q_a_proj = ReplicatedLinear(
-                    self.hidden_size,
-                    self.q_lora_rank,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_a_proj", prefix),
-                )
-                self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
-                self.q_b_proj = ColumnParallelLinear(
-                    q_lora_rank,
-                    self.num_heads * self.qk_head_dim,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_b_proj", prefix),
-                )
-            else:
-                self.q_proj = ColumnParallelLinear(
-                    self.hidden_size,
-                    self.num_heads * self.qk_head_dim,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_proj", prefix),
-                )
-            self.kv_b_proj = ColumnParallelLinear(
-                self.kv_lora_rank,
-                self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+            self.q_b_proj = ColumnParallelLinear(
+                q_lora_rank,
+                self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
-                prefix=add_prefix("kv_b_proj", prefix),
+                prefix=add_prefix("q_b_proj", prefix),
+                tp_rank=attn_tp_rank,
+                tp_size=attn_tp_size,
             )
-            # O projection.
-            self.o_proj = RowParallelLinear(
-                self.num_heads * self.v_head_dim,
+        else:
+            self.q_proj = ColumnParallelLinear(
                 self.hidden_size,
+                self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
-                prefix=add_prefix("o_proj", prefix),
+                prefix=add_prefix("q_proj", prefix),
+                tp_rank=attn_tp_rank,
+                tp_size=attn_tp_size,
             )
+        self.kv_b_proj = ColumnParallelLinear(
+            self.kv_lora_rank,
+            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("kv_b_proj", prefix),
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+        )
+        # O projection.
+        self.o_proj = RowParallelLinear(
+            self.num_heads * self.v_head_dim,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            reduce_results=reduce_results,
+            prefix=add_prefix("o_proj", prefix),
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+        )
 
         self.kv_a_proj_with_mqa = ReplicatedLinear(
             self.hidden_size,
```
```diff
@@ -542,38 +533,49 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.w_vc = None
         self.w_scale = None
 
+        self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"]
+        self.flashinfer_mla_disable_ragged = global_server_args_dict[
+            "flashinfer_mla_disable_ragged"
+        ]
+        self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
+
+    def no_absorb(self, forward_batch: ForwardBatch) -> bool:
+        if self.enable_flashinfer_mla:
+            # Flashinfer MLA: Do not absorb when enabling ragged prefill
+            return (
+                not self.flashinfer_mla_disable_ragged
+                and forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+                and forward_batch.extend_prefix_lens.sum() == 0
+            )
+        else:
+            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
+            return (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+                and forward_batch.extend_prefix_lens.sum() == 0
+            )
+
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
+        if hidden_states.shape[0] == 0:
+            assert (
+                not self.o_proj.reduce_results
+            ), "short-circuiting allreduce will lead to hangs"
+            return hidden_states
 
-        def no_absorb() -> bool:
-            if global_server_args_dict["enable_flashinfer_mla"]:
-                # Flashinfer MLA: Do not absorb when enabling ragged prefill
-                return (
-                    not global_server_args_dict["flashinfer_mla_disable_ragged"]
-                    and forward_batch.forward_mode.is_extend()
-                    and not forward_batch.forward_mode.is_target_verify()
-                    and not forward_batch.forward_mode.is_draft_extend()
-                    and forward_batch.extend_prefix_lens.sum() == 0
-                )
-            else:
-                # Triton: Use normal computation for prefill and use weight absorption for extend/decode
-                return (
-                    forward_batch.forward_mode.is_extend()
-                    and not forward_batch.forward_mode.is_target_verify()
-                    and not forward_batch.forward_mode.is_draft_extend()
-                    and forward_batch.extend_prefix_lens.sum() == 0
-                )
-
-        if no_absorb():
+        if self.no_absorb(forward_batch):
             return self.forward_normal(positions, hidden_states, forward_batch)
         else:
             if _is_hip:
                 if (
-                    os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
+                    self.rocm_fused_decode_mla
                     and forward_batch.forward_mode.is_decode()
                 ):
                     return self.forward_absorb_fused_mla_rope(
```
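For readability, here is the same dispatch restated as a standalone function with a dummy forward-mode object (plain Python; the real `ForwardBatch`/`ForwardMode` objects in sglang carry far more state than this sketch assumes):

```python
from dataclasses import dataclass


@dataclass
class FakeMode:
    extend: bool = False
    target_verify: bool = False
    draft_extend: bool = False

    def is_extend(self) -> bool:
        return self.extend

    def is_target_verify(self) -> bool:
        return self.target_verify

    def is_draft_extend(self) -> bool:
        return self.draft_extend


def no_absorb(mode: FakeMode, prefix_len_sum: int,
              enable_flashinfer_mla: bool,
              flashinfer_mla_disable_ragged: bool) -> bool:
    # True  -> take the "normal" multi-head prefill path
    # False -> take the weight-absorbed MLA path
    fresh_prefill = (
        mode.is_extend()
        and not mode.is_target_verify()
        and not mode.is_draft_extend()
        and prefix_len_sum == 0
    )
    if enable_flashinfer_mla:
        # FlashInfer backend: normal path only when ragged prefill is allowed.
        return (not flashinfer_mla_disable_ragged) and fresh_prefill
    # Triton backend: normal path for any fresh prefill, absorption otherwise.
    return fresh_prefill


print(no_absorb(FakeMode(extend=True), 0, False, False))   # True: fresh prefill
print(no_absorb(FakeMode(extend=False), 0, False, False))  # False: decode -> absorb
```

Hoisting the check into a method also lets the server-arg flags be read once in `__init__` instead of on every forward pass.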
```diff
@@ -845,34 +847,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         return output
 
 
-def all_gather(
-    input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
-):
-    all_lens = forward_batch.global_num_tokens_cpu
-    max_len = max(forward_batch.global_num_tokens_cpu)
-
-    if world_size == 1:
-        return input_tensor, 0, all_lens[0]
-
-    padded_tensor = torch.nn.functional.pad(
-        input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
-    )
-
-    group.all_gather_into_tensor(forward_batch.gathered_buffer, padded_tensor)
-
-    gathered_tensors = torch.concat(
-        [
-            forward_batch.gathered_buffer[i * max_len : i * max_len + all_lens[i]]
-            for i in range(world_size)
-        ]
-    )
-
-    start_index = 0 if rank == 0 else sum(all_lens[:rank])
-    end_index = start_index + all_lens[rank]
-
-    return gathered_tensors, start_index, end_index
-
-
 class DeepseekV2DecoderLayer(nn.Module):
 
     def __init__(
```
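The deleted module-level `all_gather` helper padded each rank's tokens to the longest local batch, gathered the padded blocks into `forward_batch.gathered_buffer`, re-packed the valid rows, and handed the caller its own `[start, end)` slice. A single-process simulation of that bookkeeping (plain torch; the collective `group.all_gather_into_tensor` is replaced by explicit copies so the sketch runs without a process group):

```python
import torch
import torch.nn.functional as F

world_size = 3
all_lens = [2, 4, 1]          # tokens held by each DP rank
hidden = 8
max_len = max(all_lens)

# Each rank's local activations, tagged with the owning rank's id.
local_parts = [torch.full((n, hidden), float(r)) for r, n in enumerate(all_lens)]

# all_gather_into_tensor would write world_size blocks of max_len padded rows.
gathered_buffer = torch.zeros(world_size * max_len, hidden)
for r, part in enumerate(local_parts):
    gathered_buffer[r * max_len:(r + 1) * max_len] = F.pad(
        part, (0, 0, 0, max_len - part.shape[0])
    )

# Drop the padding by concatenating only the valid slice of every block.
gathered = torch.cat(
    [gathered_buffer[i * max_len: i * max_len + all_lens[i]] for i in range(world_size)]
)

# Each rank's rows sit in a contiguous [start, end) range of the packed tensor.
rank = 1
start = sum(all_lens[:rank])
end = start + all_lens[rank]
assert gathered.shape[0] == sum(all_lens)
assert torch.all(gathered[start:end] == float(rank))
print(gathered.shape, (start, end))  # torch.Size([7, 8]) (2, 6)
```

In 0.4.4.post1 this per-layer helper is gone; the equivalent packing appears to be handled by `dp_gather`/`dp_scatter` from `sglang.srt.layers.dp_attention` against the shared `gathered_buffer`.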
```diff
@@ -888,14 +862,10 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        self.enable_dp_attention = (
-            not global_server_args_dict["disable_mla"]
-            and global_server_args_dict["enable_dp_attention"]
-        )
-        if self.enable_dp_attention:
-            self.tp_rank = get_tensor_model_parallel_rank()
-            self.tp_size = get_tensor_model_parallel_world_size()
-            self.tp_group = get_tp_group()
+        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
+        self.layer_id = layer_id
+        self.dp_size = get_attention_dp_size()
+
         if not global_server_args_dict["disable_mla"]:
             self.self_attn = DeepseekV2AttentionMLA(
                 config=config,
```
```diff
@@ -913,7 +883,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 max_position_embeddings=max_position_embeddings,
                 quant_config=quant_config,
                 layer_id=layer_id,
-                use_dp=self.enable_dp_attention,
+                reduce_results=False,
                 prefix=add_prefix("self_attn", prefix),
             )
         else:
```
```diff
@@ -933,8 +903,10 @@ class DeepseekV2DecoderLayer(nn.Module):
                 max_position_embeddings=max_position_embeddings,
                 quant_config=quant_config,
                 layer_id=layer_id,
+                reduce_results=False,
                 prefix=add_prefix("self_attn", prefix),
             )
+
         if is_nextn or (
             config.n_routed_experts is not None
             and layer_id >= config.first_k_dense_replace
```
```diff
@@ -965,33 +937,47 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        # Scatter
+        if self.dp_size != 1:
+            # important: forward batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
         # Self Attention
-        if not forward_batch.forward_mode.is_idle():
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        # Gather
+        if get_tensor_model_parallel_world_size() > 1:
+            # all gather and all reduce
+            if self.dp_size != 1:
+                hidden_states, local_hidden_states = (
+                    forward_batch.gathered_buffer,
+                    hidden_states,
+                )
+                dp_gather(
+                    hidden_states, local_hidden_states, forward_batch, self.layer_id
+                )
             else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
 
-            hidden_states = self.self_attn(
-                positions=positions,
-                hidden_states=hidden_states,
-                forward_batch=forward_batch,
-            )
-            hidden_states, residual = self.post_attention_layernorm(
-                hidden_states, residual
-            )
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
 
         # Fully Connected
-        if self.enable_dp_attention:
-            hidden_states, start_idx, end_idx = all_gather(
-                hidden_states, forward_batch, self.tp_rank, self.tp_size, self.tp_group
-            )
-            hidden_states = self.mlp(hidden_states)
-            hidden_states = hidden_states[start_idx:end_idx]
-        else:
-            hidden_states = self.mlp(hidden_states)
-
+        hidden_states = self.mlp(hidden_states)
         return hidden_states, residual
 
 
```
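The rewritten decoder layer runs layer norm on the gathered (global) token set, scatters each DP rank's slice out of `gathered_buffer`, runs attention on the local slice, then gathers the slices back before the shared MLP (or, when `dp_size == 1`, simply all-reduces the partial `o_proj` outputs). The sketch below fakes `dp_scatter`/`dp_gather` with explicit slicing for two DP ranks to show only the buffer bookkeeping; it is not the sglang implementation and assumes tokens are packed rank by rank without padding:

```python
import torch

# Two DP ranks holding 3 and 2 tokens of a 5-token global batch.
tokens_per_rank = [3, 2]
hidden = 4
offsets = [0, 3, 5]  # prefix sums of tokens_per_rank
global_hidden = torch.arange(5 * hidden, dtype=torch.float32).reshape(5, hidden)


def fake_dp_scatter(rank: int, global_tensor: torch.Tensor) -> torch.Tensor:
    # Local view: this rank's contiguous slice of the packed global batch.
    return global_tensor[offsets[rank]:offsets[rank + 1]].clone()


def fake_dp_gather(local_slices: list) -> torch.Tensor:
    # Re-pack the per-rank slices into the shared gathered buffer.
    return torch.cat(local_slices, dim=0)


# "Attention" stands in for self_attn called with reduce_results=False.
local_out = [fake_dp_scatter(r, global_hidden) * 10 for r in range(len(tokens_per_rank))]
regathered = fake_dp_gather(local_out)

assert regathered.shape == global_hidden.shape
assert torch.allclose(regathered, global_hidden * 10)
print(regathered[:2])
```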
```diff
@@ -1027,12 +1013,27 @@ class DeepseekV2Model(nn.Module):
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
+        self.dp_size = get_attention_dp_size()
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
+
+        # Gather
+        if self.dp_size != 1:
+            input_ids, local_input_ids = (
+                torch.empty(
+                    (forward_batch.gathered_buffer.shape[0],),
+                    dtype=input_ids.dtype,
+                    device=input_ids.device,
+                ),
+                input_ids,
+            )
+            dp_gather(input_ids, local_input_ids, forward_batch, "embedding")
+
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
```
```diff
@@ -1059,22 +1060,14 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.model = DeepseekV2Model(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
-        if global_server_args_dict["enable_dp_attention"]:
-            self.lm_head = ReplicatedLinear(
-                config.hidden_size,
-                config.vocab_size,
-                bias=False,
-                prefix=add_prefix("lm_head", prefix),
-            )
-            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
-        else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("lm_head", prefix),
-            )
-            self.logits_processor = LogitsProcessor(config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+        )
+        self.logits_processor = LogitsProcessor(config)
+        self.dp_size = get_attention_dp_size()
 
     @torch.no_grad()
     def forward(
```
```diff
@@ -1084,6 +1077,16 @@ class DeepseekV2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
+
+        if self.dp_size != 1:
+            # important: forward batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head, forward_batch
         )
```