sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -7
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +25 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -2
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +29 -4
- sglang/srt/entrypoints/http_server.py +76 -0
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/entrypoints/openai/serving_chat.py +23 -6
- sglang/srt/entrypoints/openai/serving_completions.py +10 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +14 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
- sglang/srt/layers/attention/triton_backend.py +109 -73
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
- sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +58 -10
- sglang/srt/layers/dp_attention.py +137 -27
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +16 -18
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +18 -46
- sglang/srt/layers/quantization/awq.py +22 -23
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +17 -21
- sglang/srt/layers/quantization/marlin_utils.py +26 -8
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +217 -98
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +222 -39
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +77 -2
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/layers.py +6 -2
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +80 -19
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +23 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +22 -48
- sglang/srt/managers/scheduler.py +28 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +88 -39
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +10 -157
- sglang/srt/mem_cache/allocator_ascend.py +147 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +33 -33
- sglang/srt/model_executor/forward_batch_info.py +11 -10
- sglang/srt/model_executor/model_runner.py +93 -78
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +5 -2
- sglang/srt/models/deepseek_v2.py +226 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +27 -65
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +41 -76
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama.py +10 -2
- sglang/srt/models/llama4.py +18 -7
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +23 -23
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +84 -0
- sglang/srt/models/qwen3_moe.py +27 -43
- sglang/srt/models/step3_vl.py +8 -3
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +22 -2
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +264 -105
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +20 -19
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
- sglang/srt/layers/quantization/fp4.py +0 -557
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
```diff
@@ -20,7 +20,7 @@ import concurrent.futures
 import logging
 import os
 from enum import IntEnum, auto
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -30,6 +30,7 @@ from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
+    get_pp_group,
     get_tensor_model_parallel_world_size,
     parallel_state,
     tensor_model_parallel_all_reduce,
@@ -50,7 +51,7 @@ from sglang.srt.layers.communicator import (
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
-
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -60,9 +61,14 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import (
+    get_deepep_mode,
+    get_moe_a2a_backend,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+)
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
-from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
@@ -82,13 +88,13 @@ from sglang.srt.layers.quantization.int8_utils import (
 )
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
-from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.two_batch_overlap import (
     MaybeTboDeepEPDispatcher,
@@ -109,6 +115,7 @@ from sglang.srt.utils import (
     is_hip,
     is_non_idle_and_non_empty,
     log_info_on_rank0,
+    make_layers,
     use_intel_amx_backend,
 )
 
@@ -312,18 +319,7 @@ class DeepseekV2MoE(nn.Module):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )
 
-        self.
-            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
-            renormalize=config.norm_topk_prob,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            topk_group=config.topk_group,
-            correction_bias=self.gate.e_score_correction_bias,
-            routed_scaling_factor=self.routed_scaling_factor,
-        )
-
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
@@ -335,30 +331,19 @@ class DeepseekV2MoE(nn.Module):
             quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
             prefix=add_prefix("experts", prefix),
-
-
-
-
-
-
-
-
-
-
-
-
-
-            **(
-                dict(
-                    renormalize=config.norm_topk_prob,
-                    use_grouped_topk=True,
-                    num_expert_group=config.n_group,
-                    topk_group=config.topk_group,
-                    correction_bias=self.gate.e_score_correction_bias,
-                )
-                if should_use_flashinfer_trtllm_moe()
-                else {}
-            ),
+        )
+
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            renormalize=config.norm_topk_prob,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
+            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
+            force_topk=quant_config is None,
         )
 
         self.shared_experts_is_int8 = False
@@ -366,7 +351,7 @@ class DeepseekV2MoE(nn.Module):
         self.shared_experts_weight_block_size = None
         if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
-            # disable tp for shared experts when enable deepep moe
+            # disable tp for shared experts when enable deepep moe, or with fp4 allgather
             self.shared_experts = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=intermediate_size,
@@ -376,7 +361,8 @@ class DeepseekV2MoE(nn.Module):
                 prefix=add_prefix("shared_experts", prefix),
                 **(
                     dict(tp_rank=0, tp_size=1)
-                    if
+                    if get_moe_a2a_backend().is_deepep()
+                    or should_use_flashinfer_cutlass_moe_fp4_allgather()
                     else {}
                 ),
             )
@@ -406,7 +392,7 @@ class DeepseekV2MoE(nn.Module):
 
         self.top_k = config.num_experts_per_tok
 
-        if
+        if get_moe_a2a_backend().is_deepep():
             # TODO: we will support tp < ep in the future
             self.ep_size = get_moe_expert_parallel_world_size()
             self.num_experts = (
@@ -430,12 +416,12 @@ class DeepseekV2MoE(nn.Module):
                 num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
-                deepep_mode=
+                deepep_mode=get_deepep_mode(),
                 async_finish=True,
                 return_recv_hook=True,
             )
 
-        self._enable_deepep_moe =
+        self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()
 
     def get_moe_weights(self):
         return [
@@ -456,14 +442,19 @@ class DeepseekV2MoE(nn.Module):
             if (
                 self.alt_stream is not None
                 and self.num_fused_shared_experts == 0
+                and hidden_states.shape[0] > 0
                 and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
             ):
                 return self.forward_normal_dual_stream(
-                    hidden_states,
+                    hidden_states,
+                    should_allreduce_fusion,
+                    use_reduce_scatter,
                 )
             else:
                 return self.forward_normal(
-                    hidden_states,
+                    hidden_states,
+                    should_allreduce_fusion,
+                    use_reduce_scatter,
                 )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
@@ -482,25 +473,24 @@ class DeepseekV2MoE(nn.Module):
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-
-
-            # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
-            # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
-            if should_use_flashinfer_trtllm_moe():
-                kwargs["topk_output"] = (self.topk, router_logits)
-            else:
-                kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-
-            final_hidden_states = self.experts(**kwargs)
+            topk_output = self.topk(hidden_states, router_logits)
+            final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
+
         current_stream.wait_stream(self.alt_stream)
         with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
             final_hidden_states_out = torch.empty_like(final_hidden_states)
+
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
-        if
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
 
@@ -515,19 +505,16 @@ class DeepseekV2MoE(nn.Module):
         ):
             return self.forward_cpu(hidden_states, should_allreduce_fusion)
 
-
-
-
-
-
-        # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
-        # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
-        if should_use_flashinfer_trtllm_moe():
-            kwargs["topk_output"] = (self.topk, router_logits)
+        if hidden_states.shape[0] > 0:
+            shared_output = self._forward_shared_experts(hidden_states)
+            # router_logits: (num_tokens, n_experts)
+            router_logits = self.gate(hidden_states)
+            topk_output = self.topk(hidden_states, router_logits)
         else:
-
+            shared_output = None
+            topk_output = self.topk.empty_topk_output(hidden_states.device)
 
-        final_hidden_states = self.experts(
+        final_hidden_states = self.experts(hidden_states, topk_output)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
@@ -537,7 +524,12 @@ class DeepseekV2MoE(nn.Module):
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
-        if
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
 
@@ -616,11 +608,8 @@ class DeepseekV2MoE(nn.Module):
                 ),
             )
         else:
-            topk_idx =
-
-            )
-            topk_weights = torch.empty(
-                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
+                hidden_states.device
             )
 
         final_hidden_states = self.experts(
```
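The MoE hunks above converge every routing path on a plain `topk_output = self.topk(hidden_states, router_logits)` followed by `self.experts(hidden_states, topk_output)`, and add a zero-token guard so an idle data-parallel rank produces an empty top-k result instead of running the gate on an empty batch. The snippet below is only an illustrative, self-contained sketch of that guard pattern; `route_tokens` and its arguments are invented for the example and are not part of the sglang API.

```python
# Illustrative sketch (not sglang's implementation) of the zero-token guard:
# an idle rank still returns an "empty" top-k result with the expected trailing
# dimension so downstream expert kernels and collectives stay shape-consistent.
import torch


def route_tokens(hidden_states: torch.Tensor, gate: torch.nn.Linear, top_k: int):
    """Return (weights, indices); empty tensors when the batch has no tokens."""
    if hidden_states.shape[0] > 0:
        router_logits = gate(hidden_states)                    # (num_tokens, n_experts)
        probs = torch.softmax(router_logits, dim=-1)
        weights, indices = torch.topk(probs, top_k, dim=-1)    # (num_tokens, top_k)
    else:
        # Zero rows, correct trailing dimension.
        weights = torch.empty((0, top_k), dtype=torch.float32, device=hidden_states.device)
        indices = torch.empty((0, top_k), dtype=torch.long, device=hidden_states.device)
    return weights, indices


if __name__ == "__main__":
    gate = torch.nn.Linear(16, 8)
    print(route_tokens(torch.randn(4, 16), gate, top_k=2)[1].shape)  # torch.Size([4, 2])
    print(route_tokens(torch.randn(0, 16), gate, top_k=2)[1].shape)  # torch.Size([0, 2])
```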
```diff
@@ -1006,29 +995,33 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         if attention_backend == "ascend":
             return AttnForwardMethod.MLA
-        elif
+        elif (
+            attention_backend == "flashinfer"
+            or attention_backend == "fa3"
+            or attention_backend == "flashmla"
+            or attention_backend == "trtllm_mla"
+            or attention_backend == "cutlass_mla"
+        ):
+            # Use MHA with chunked KV cache when prefilling on long sequences.
+            sum_extend_prefix_lens = (
+                sum(forward_batch.extend_prefix_lens_cpu)
+                if forward_batch.extend_prefix_lens_cpu is not None
+                else 0
+            )
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
+            disable_ragged = (
+                attention_backend == "flashinfer" or attention_backend == "flashmla"
+            ) and self.flashinfer_mla_disable_ragged
             if (
-                not
+                not disable_ragged
                 and forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
-                and sum(forward_batch.extend_prefix_lens_cpu) == 0
-            ):
-                return AttnForwardMethod.MHA
-            else:
-                return _dispatch_mla_subtype()
-        elif attention_backend == "fa3":
-            # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
-            if forward_batch.extend_prefix_lens_cpu is not None:
-                sum_extend_prefix_lens = sum(forward_batch.extend_prefix_lens_cpu)
-            if (
-                forward_batch.forward_mode.is_extend()
-                and not self.disable_chunked_prefix_cache
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
                 and (
-
+                    (
+                        sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
+                        and not self.disable_chunked_prefix_cache
+                    )
                     or sum_extend_prefix_lens == 0
                 )
             ):
@@ -1696,7 +1689,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             k[..., self.qk_nope_head_dim :] = k_pe
 
             output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
-            lse = torch.transpose(lse, 0, 1).contiguous()
             tmp_output = torch.empty_like(accum_output)
             tmp_lse = torch.empty_like(accum_lse)
             merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
@@ -1718,55 +1710,26 @@ class DeepseekV2AttentionMLA(nn.Module):
         # will be helpful for understanding the purpose of this function.
 
         # First do normal mha forward to get output for extended part
-
-
-                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
-            )
-            q = self.q_a_layernorm(q)
-            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
-        else:
-            q = self.q_proj(hidden_states)[0].view(
-                -1, self.num_local_heads, self.qk_head_dim
-            )
-            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
-        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        latent_cache = latent_cache.unsqueeze(1)
-        kv_a = self.kv_a_layernorm(kv_a)
-        kv = self.kv_b_proj(kv_a)[0]
-        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-        k_nope = kv[..., : self.qk_nope_head_dim]
-        v = kv[..., self.qk_nope_head_dim :]
-        k_pe = latent_cache[:, :, self.kv_lora_rank :]
-
-        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
-        q[..., self.qk_nope_head_dim :] = q_pe
-        k = torch.empty_like(q)
-        k[..., : self.qk_nope_head_dim] = k_nope
-        k[..., self.qk_nope_head_dim :] = k_pe
-
-        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
-        latent_cache[:, :, self.kv_lora_rank :] = k_pe
-
-        # Save latent cache
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+        return self.forward_normal_prepare(
+            positions, hidden_states, forward_batch, zero_allocator
         )
 
-        return q, k, v, forward_batch
-
     def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
+        has_extend_prefix = any(forward_batch.extend_prefix_lens_cpu)
+        # Only initialize the info once
+        if has_extend_prefix and forward_batch.num_prefix_chunks is None:
+            forward_batch.prepare_chunked_prefix_cache_info(q.device)
+            if hasattr(forward_batch.attn_backend, "init_mha_chunk_metadata"):
+                forward_batch.attn_backend.init_mha_chunk_metadata(forward_batch)
+
+        forward_batch.mha_return_lse = has_extend_prefix
         # Do mha for extended part without prefix
         forward_batch.set_attn_attend_prefix_cache(False)
-        attn_output
-        lse = torch.transpose(lse, 0, 1).contiguous()
+        attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
 
         # Do mha attention with chunked prefix cache if there are any sequence with prefix
-        if
-
-        if forward_batch.num_prefix_chunks is None:
-            forward_batch.prepare_chunked_prefix_cache_info(q.device)
-
+        if has_extend_prefix:
+            attn_output, lse = attn_output
             forward_batch.set_attn_attend_prefix_cache(True)
             attn_output = self._chunked_prefix_attn_mha(
                 q=q,
```
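In the chunked prefix-cache path above, `forward_normal_chunked_kv_core` now only requests the log-sum-exp (`mha_return_lse`) when some sequence actually has a cached prefix, and later merges the per-chunk results (`merge_state_v2`). The sketch below shows the underlying merge rule in plain PyTorch; `merge_attention` is an illustrative stand-in, not sglang's `merge_state_v2` signature.

```python
# Sketch of merging attention outputs computed over disjoint KV chunks using
# their log-sum-exp statistics (illustrative only; names are not sglang's API).
import torch


def merge_attention(o1, lse1, o2, lse2):
    """Combine partial outputs o_i (tokens, heads, dim) weighted by their
    softmax mass exp(lse_i) (tokens, heads), returning the merged output
    and the merged log-sum-exp."""
    lse_max = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - lse_max)   # un-normalized softmax mass of chunk 1
    w2 = torch.exp(lse2 - lse_max)
    denom = w1 + w2
    merged = (o1 * w1.unsqueeze(-1) + o2 * w2.unsqueeze(-1)) / denom.unsqueeze(-1)
    merged_lse = lse_max + torch.log(denom)
    return merged, merged_lse


if __name__ == "__main__":
    o1, o2 = torch.randn(3, 4, 8), torch.randn(3, 4, 8)
    lse1, lse2 = torch.randn(3, 4), torch.randn(3, 4)
    out, lse = merge_attention(o1, lse1, o2, lse2)
    print(out.shape, lse.shape)  # torch.Size([3, 4, 8]) torch.Size([3, 4])
```

Because each partial output carries its normalizer in log space, the chunks can be combined exactly without revisiting their KV pages.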
```diff
@@ -1797,7 +1760,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
         self.layer_id = layer_id
         self.is_nextn = is_nextn
@@ -1866,10 +1828,11 @@ class DeepseekV2DecoderLayer(nn.Module):
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
             allow_reduce_scatter=True,
+            is_last_layer=(
+                is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
+            ),
         )
 
-        self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
-
     def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
         return is_nextn or (
             self.config.n_routed_experts is not None
@@ -1877,20 +1840,6 @@ class DeepseekV2DecoderLayer(nn.Module):
             and layer_id % self.config.moe_layer_freq == 0
         )
 
-    def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
-        """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
-
-        batch_size = (
-            forward_batch.input_ids.shape[0]
-            if hasattr(forward_batch, "input_ids")
-            else 0
-        )
-
-        if batch_size > 128:
-            return False
-
-        return self._fuse_allreduce_lookup_table.get(batch_size, False)
-
     def forward(
         self,
         positions: torch.Tensor,
@@ -1916,9 +1865,9 @@ class DeepseekV2DecoderLayer(nn.Module):
         )
 
         should_allreduce_fusion = (
-            self.
-
-
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
+            )
         )
 
         # For DP with padding, reduce scatter can be used instead of all-reduce.
@@ -2009,26 +1958,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         )
         return output
 
-    def _build_fuse_allreduce_lookup_table(self):
-        static_conditions_met = (
-            self.layer_id != self.config.num_hidden_layers - 1
-            and get_tensor_model_parallel_world_size() > 1
-            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
-            and _is_sm100_supported
-            and _is_flashinfer_available
-        )
-
-        if not static_conditions_met:
-            return {}
-
-        lookup_table = {}
-        for batch_size in range(129):  # 0 to 128
-            is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
-            should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
-            lookup_table[batch_size] = should_fuse
-
-        return lookup_table
-
 
 class DeepseekV2Model(nn.Module):
     fall_back_to_pt_during_load = False
@@ -2043,26 +1972,52 @@ class DeepseekV2Model(nn.Module):
         self.padding_id = config.pad_token_id
         self.vocab_size = config.vocab_size
         self.first_k_dense_replace = config.first_k_dense_replace
+        self.pp_group = get_pp_group()
+
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                enable_tp=not is_dp_attention_enabled(),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
 
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
-        )
         self.alt_stream = torch.cuda.Stream() if _is_cuda else None
-        self.layers =
-
-
-
-
-
-
-
-
-
-
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: DeepseekV2DecoderLayer(
+                config=config,
+                layer_id=idx,
+                quant_config=quant_config,
+                prefix=prefix,
+                alt_stream=self.alt_stream,
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix=add_prefix("layers", prefix),
+            offloader_kwargs=dict(
+                submodule_accessor=lambda layer: (
+                    layer.mlp.experts
+                    if isinstance(layer.mlp, DeepseekV2MoE)
+                    else layer.mlp
+                ),
+                whitelist_param_names_creator=lambda module: (
+                    [
+                        "w13_weight",
+                        "w2_weight",
+                        "w13_blockscale_swizzled",
+                        "w2_blockscale_swizzled",
+                    ]
+                    if isinstance(module, FusedMoE)
+                    else []
+                ),
+            ),
         )
-        self.
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
 
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
@@ -2073,8 +2028,9 @@ class DeepseekV2Model(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-
-
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        total_num_layers = self.end_layer - self.start_layer
         device = input_embeds.device if input_embeds is not None else input_ids.device
         zero_allocator = BumpAllocator(
             buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1),
@@ -2082,44 +2038,62 @@ class DeepseekV2Model(nn.Module):
             device=device,
         )
 
-        if
-
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
         else:
-
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
 
-
+        normal_start_layer = self.start_layer
+        normal_end_layer = self.end_layer
+        if forward_batch.can_run_tbo:
+            if (
+                self.first_k_dense_replace > normal_start_layer
+                and self.first_k_dense_replace < normal_end_layer
+            ):
+                normal_end_layer = self.first_k_dense_replace
+            elif self.first_k_dense_replace < normal_start_layer:
+                normal_end_layer = normal_start_layer = 0
 
-
-            self.first_k_dense_replace
-            if forward_batch.can_run_tbo
-            else total_num_layers
-        )
-        for i in range(normal_num_layers):
+        for i in range(normal_start_layer, normal_end_layer):
             with get_global_expert_distribution_recorder().with_current_layer(i):
                 layer = self.layers[i]
                 hidden_states, residual = layer(
                     positions, hidden_states, forward_batch, residual, zero_allocator
                 )
 
-        if
+        if normal_end_layer != self.end_layer:
             hidden_states, residual = model_forward_maybe_tbo(
-                layers=self.layers[
+                layers=self.layers[normal_end_layer : self.end_layer],
                 enable_tbo=True,
                 positions=positions,
                 forward_batch=forward_batch,
                 hidden_states=hidden_states,
                 residual=residual,
                 input_data_scatter_mode=self.layers[
-
+                    normal_end_layer - 1
                 ].layer_scatter_modes.layer_output_mode,
                 zero_allocator=zero_allocator,
             )
 
-        if not
-
-
-
-
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            if not forward_batch.forward_mode.is_idle():
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
 
@@ -2146,6 +2120,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             "kv_a_proj_with_mqa",
         ]
 
+        self.pp_group = get_pp_group()
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
@@ -2215,13 +2190,27 @@ class DeepseekV2ForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(
-
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+        hidden_states = self.model(
+            input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors
         )
 
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
+
     def post_load_weights(self, is_nextn=False, weight_names=None):
 
         # Perform post-processing after loading weights
@@ -2229,7 +2218,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             layer_ids = [self.config.num_hidden_layers]
         else:
             if weight_names is None:
-                layer_ids = range(self.
+                layer_ids = range(self.model.start_layer, self.model.end_layer)
             else:
                 layer_ids = set()
                 for name in weight_names:
@@ -2476,17 +2465,15 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping =
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
        )
        if self.quant_config and self.quant_config.get_name() == "w4afp8":
-            expert_params_mapping += (
-
-                num_experts=self.config.n_routed_experts
-            )
+            expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
+                num_experts=self.config.n_routed_experts
             )
 
         # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
@@ -2513,6 +2500,16 @@ class DeepseekV2ForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
         weight_names = []
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
             if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
                 name = name.replace(
                     "mlp.shared_experts",
@@ -2597,6 +2594,12 @@ class DeepseekV2ForCausalLM(nn.Module):
             # Skip loading extra bias for GPTQ models.
             if name.endswith(".bias") and name not in params_dict:
                 continue
+            # Skip loading embed_tokens if not first rank in pipeline parallelism
+            if ".embed_tokens." in name and not self.pp_group.is_first_rank:
+                continue
+            # Skip loading norm if not last rank in pipeline parallelism
+            if ".norm." in name and not self.pp_group.is_last_rank:
+                continue
             if fuse_qkv_a_proj and (
                 "q_a_proj" in name or "kv_a_proj_with_mqa" in name
             ):
```
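Taken together, the `DeepseekV2Model` / `DeepseekV2ForCausalLM` hunks wire the model for pipeline parallelism: only the first rank owns `embed_tokens`, only the last rank owns the final norm and the logits processor, intermediate ranks exchange `hidden_states`/`residual` through `PPProxyTensors`, and weight loading skips layers outside `[start_layer, end_layer)`. The following is a minimal, self-contained sketch of that dispatch pattern with made-up stand-in classes (`TinyPPModel`, `ProxyTensors`); it is not sglang's implementation, which uses `PPMissingLayer`, `PPProxyTensors`, and `make_layers`.

```python
# Minimal sketch of the pipeline-parallel forward pattern (illustrative only).
from dataclasses import dataclass, field
from typing import Dict, Optional, Union

import torch


@dataclass
class ProxyTensors:
    tensors: Dict[str, torch.Tensor] = field(default_factory=dict)

    def __getitem__(self, key: str) -> torch.Tensor:
        return self.tensors[key]


class TinyPPModel(torch.nn.Module):
    def __init__(self, hidden: int, n_layers: int, pp_rank: int, pp_size: int):
        super().__init__()
        self.is_first_rank = pp_rank == 0
        self.is_last_rank = pp_rank == pp_size - 1
        # Each rank materializes only its slice of the layer stack.
        per_rank = n_layers // pp_size
        self.start_layer = pp_rank * per_rank
        self.end_layer = self.start_layer + per_rank
        self.embed = torch.nn.Embedding(100, hidden) if self.is_first_rank else None
        self.layers = torch.nn.ModuleList(
            torch.nn.Linear(hidden, hidden) for _ in range(self.start_layer, self.end_layer)
        )
        self.norm = torch.nn.LayerNorm(hidden) if self.is_last_rank else None

    def forward(
        self,
        input_ids: torch.Tensor,
        pp_proxy_tensors: Optional[ProxyTensors] = None,
    ) -> Union[torch.Tensor, ProxyTensors]:
        if self.is_first_rank:
            hidden_states = self.embed(input_ids)
        else:
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]
        for layer in self.layers:
            hidden_states = torch.relu(layer(hidden_states))
        if not self.is_last_rank:
            # Hand off to the next pipeline stage instead of producing final output.
            return ProxyTensors({"hidden_states": hidden_states})
        return self.norm(hidden_states)


if __name__ == "__main__":
    ids = torch.randint(0, 100, (2, 5))
    stage0 = TinyPPModel(hidden=32, n_layers=4, pp_rank=0, pp_size=2)
    stage1 = TinyPPModel(hidden=32, n_layers=4, pp_rank=1, pp_size=2)
    proxy = stage0(ids)
    out = stage1(ids, pp_proxy_tensors=proxy)
    print(out.shape)  # torch.Size([2, 5, 32])
```

In a real deployment each stage runs in its own process and the proxy tensors travel over a point-to-point channel; here the hand-off is a plain Python object so the sketch stays runnable on one device.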