sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -6
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +24 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +27 -2
- sglang/srt/entrypoints/http_server.py +12 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -2
- sglang/srt/entrypoints/openai/serving_chat.py +22 -6
- sglang/srt/entrypoints/openai/serving_completions.py +9 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +11 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
- sglang/srt/layers/attention/triton_backend.py +85 -46
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
- sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +51 -3
- sglang/srt/layers/dp_attention.py +23 -4
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +5 -1
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/quantization/__init__.py +13 -14
- sglang/srt/layers/quantization/awq.py +7 -7
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +5 -4
- sglang/srt/layers/quantization/marlin_utils.py +11 -3
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +165 -68
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +206 -37
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +25 -0
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +76 -18
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +9 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +4 -9
- sglang/srt/managers/scheduler.py +25 -16
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +60 -21
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +7 -5
- sglang/srt/mem_cache/allocator_ascend.py +0 -11
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +25 -12
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +43 -32
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +3 -1
- sglang/srt/models/deepseek_v2.py +224 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +25 -63
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +34 -74
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama4.py +0 -2
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +3 -18
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +7 -1
- sglang/srt/models/qwen3_moe.py +9 -38
- sglang/srt/models/step3_vl.py +2 -1
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +6 -1
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/server_args.py +237 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +16 -11
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
- sglang/srt/layers/quantization/fp4.py +0 -557
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -20,7 +20,7 @@ import concurrent.futures
 import logging
 import os
 from enum import IntEnum, auto
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -30,6 +30,7 @@ from transformers import PretrainedConfig

 from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
+    get_pp_group,
     get_tensor_model_parallel_world_size,
     parallel_state,
     tensor_model_parallel_all_reduce,
@@ -50,7 +51,6 @@ from sglang.srt.layers.communicator import (
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
     is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
@@ -61,9 +61,14 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import (
+    get_deepep_mode,
+    get_moe_a2a_backend,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+)
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
-from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
@@ -83,13 +88,13 @@ from sglang.srt.layers.quantization.int8_utils import (
 )
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
-from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.two_batch_overlap import (
     MaybeTboDeepEPDispatcher,
@@ -110,6 +115,7 @@ from sglang.srt.utils import (
     is_hip,
     is_non_idle_and_non_empty,
     log_info_on_rank0,
+    make_layers,
     use_intel_amx_backend,
 )

@@ -313,18 +319,7 @@ class DeepseekV2MoE(nn.Module):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )

-        self.topk = TopK(
-            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
-            renormalize=config.norm_topk_prob,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            topk_group=config.topk_group,
-            correction_bias=self.gate.e_score_correction_bias,
-            routed_scaling_factor=self.routed_scaling_factor,
-        )
-
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
@@ -336,30 +331,19 @@ class DeepseekV2MoE(nn.Module):
             quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
             prefix=add_prefix("experts", prefix),
-            … (13 removed lines not shown in this view)
-            **(
-                dict(
-                    renormalize=config.norm_topk_prob,
-                    use_grouped_topk=True,
-                    num_expert_group=config.n_group,
-                    topk_group=config.topk_group,
-                    correction_bias=self.gate.e_score_correction_bias,
-                )
-                if should_use_flashinfer_trtllm_moe()
-                else {}
-            ),
+        )
+
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            renormalize=config.norm_topk_prob,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
+            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
+            force_topk=quant_config is None,
         )

         self.shared_experts_is_int8 = False
@@ -367,7 +351,7 @@ class DeepseekV2MoE(nn.Module):
         self.shared_experts_weight_block_size = None
         if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
-            # disable tp for shared experts when enable deepep moe
+            # disable tp for shared experts when enable deepep moe, or with fp4 allgather
             self.shared_experts = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=intermediate_size,
@@ -377,7 +361,8 @@ class DeepseekV2MoE(nn.Module):
                 prefix=add_prefix("shared_experts", prefix),
                 **(
                     dict(tp_rank=0, tp_size=1)
-                    if …
+                    if get_moe_a2a_backend().is_deepep()
+                    or should_use_flashinfer_cutlass_moe_fp4_allgather()
                     else {}
                 ),
             )
@@ -407,7 +392,7 @@ class DeepseekV2MoE(nn.Module):

         self.top_k = config.num_experts_per_tok

-        if …
+        if get_moe_a2a_backend().is_deepep():
             # TODO: we will support tp < ep in the future
             self.ep_size = get_moe_expert_parallel_world_size()
             self.num_experts = (
@@ -431,12 +416,12 @@ class DeepseekV2MoE(nn.Module):
                 num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
-                deepep_mode=…
+                deepep_mode=get_deepep_mode(),
                 async_finish=True,
                 return_recv_hook=True,
             )

-        self._enable_deepep_moe = …
+        self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()

     def get_moe_weights(self):
         return [
@@ -457,14 +442,19 @@ class DeepseekV2MoE(nn.Module):
             if (
                 self.alt_stream is not None
                 and self.num_fused_shared_experts == 0
+                and hidden_states.shape[0] > 0
                 and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
             ):
                 return self.forward_normal_dual_stream(
-                    hidden_states, …
+                    hidden_states,
+                    should_allreduce_fusion,
+                    use_reduce_scatter,
                 )
             else:
                 return self.forward_normal(
-                    hidden_states, …
+                    hidden_states,
+                    should_allreduce_fusion,
+                    use_reduce_scatter,
                 )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
@@ -483,25 +473,24 @@ class DeepseekV2MoE(nn.Module):
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            … (2 removed lines not shown in this view)
-            # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
-            # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
-            if should_use_flashinfer_trtllm_moe():
-                kwargs["topk_output"] = (self.topk, router_logits)
-            else:
-                kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-
-            final_hidden_states = self.experts(**kwargs)
+            topk_output = self.topk(hidden_states, router_logits)
+            final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
+
         current_stream.wait_stream(self.alt_stream)
         with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
             final_hidden_states_out = torch.empty_like(final_hidden_states)
+
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
-        if …
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

@@ -516,19 +505,16 @@ class DeepseekV2MoE(nn.Module):
         ):
             return self.forward_cpu(hidden_states, should_allreduce_fusion)

-        … (5 removed lines not shown in this view)
-        # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
-        # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
-        if should_use_flashinfer_trtllm_moe():
-            kwargs["topk_output"] = (self.topk, router_logits)
+        if hidden_states.shape[0] > 0:
+            shared_output = self._forward_shared_experts(hidden_states)
+            # router_logits: (num_tokens, n_experts)
+            router_logits = self.gate(hidden_states)
+            topk_output = self.topk(hidden_states, router_logits)
         else:
-            …
+            shared_output = None
+            topk_output = self.topk.empty_topk_output(hidden_states.device)

-        final_hidden_states = self.experts(…
+        final_hidden_states = self.experts(hidden_states, topk_output)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
@@ -538,7 +524,12 @@ class DeepseekV2MoE(nn.Module):
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
-        if …
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

@@ -617,11 +608,8 @@ class DeepseekV2MoE(nn.Module):
                 ),
             )
         else:
-            topk_idx = …
-                …
-            )
-            topk_weights = torch.empty(
-                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
+                hidden_states.device
             )

         final_hidden_states = self.experts(
@@ -1007,29 +995,33 @@ class DeepseekV2AttentionMLA(nn.Module):

         if attention_backend == "ascend":
             return AttnForwardMethod.MLA
-        elif …
+        elif (
+            attention_backend == "flashinfer"
+            or attention_backend == "fa3"
+            or attention_backend == "flashmla"
+            or attention_backend == "trtllm_mla"
+            or attention_backend == "cutlass_mla"
+        ):
+            # Use MHA with chunked KV cache when prefilling on long sequences.
+            sum_extend_prefix_lens = (
+                sum(forward_batch.extend_prefix_lens_cpu)
+                if forward_batch.extend_prefix_lens_cpu is not None
+                else 0
+            )
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
+            disable_ragged = (
+                attention_backend == "flashinfer" or attention_backend == "flashmla"
+            ) and self.flashinfer_mla_disable_ragged
             if (
-                not …
+                not disable_ragged
                 and forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
-                and sum(forward_batch.extend_prefix_lens_cpu) == 0
-            ):
-                return AttnForwardMethod.MHA
-            else:
-                return _dispatch_mla_subtype()
-        elif attention_backend == "fa3":
-            # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
-            if forward_batch.extend_prefix_lens_cpu is not None:
-                sum_extend_prefix_lens = sum(forward_batch.extend_prefix_lens_cpu)
-            if (
-                forward_batch.forward_mode.is_extend()
-                and not self.disable_chunked_prefix_cache
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
                 and (
-                    …
+                    (
+                        sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
+                        and not self.disable_chunked_prefix_cache
+                    )
                     or sum_extend_prefix_lens == 0
                 )
             ):
@@ -1697,7 +1689,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             k[..., self.qk_nope_head_dim :] = k_pe

             output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
-            lse = torch.transpose(lse, 0, 1).contiguous()
             tmp_output = torch.empty_like(accum_output)
             tmp_lse = torch.empty_like(accum_lse)
             merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
@@ -1719,55 +1710,26 @@ class DeepseekV2AttentionMLA(nn.Module):
         # will be helpful for understanding the purpose of this function.

         # First do normal mha forward to get output for extended part
-        … (2 removed lines not shown in this view)
-                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
-            )
-            q = self.q_a_layernorm(q)
-            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
-        else:
-            q = self.q_proj(hidden_states)[0].view(
-                -1, self.num_local_heads, self.qk_head_dim
-            )
-        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
-        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        latent_cache = latent_cache.unsqueeze(1)
-        kv_a = self.kv_a_layernorm(kv_a)
-        kv = self.kv_b_proj(kv_a)[0]
-        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-        k_nope = kv[..., : self.qk_nope_head_dim]
-        v = kv[..., self.qk_nope_head_dim :]
-        k_pe = latent_cache[:, :, self.kv_lora_rank :]
-
-        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
-        q[..., self.qk_nope_head_dim :] = q_pe
-        k = torch.empty_like(q)
-        k[..., : self.qk_nope_head_dim] = k_nope
-        k[..., self.qk_nope_head_dim :] = k_pe
-
-        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
-        latent_cache[:, :, self.kv_lora_rank :] = k_pe
-
-        # Save latent cache
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+        return self.forward_normal_prepare(
+            positions, hidden_states, forward_batch, zero_allocator
         )

-        return q, k, v, forward_batch
-
     def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
+        has_extend_prefix = any(forward_batch.extend_prefix_lens_cpu)
+        # Only initialize the info once
+        if has_extend_prefix and forward_batch.num_prefix_chunks is None:
+            forward_batch.prepare_chunked_prefix_cache_info(q.device)
+            if hasattr(forward_batch.attn_backend, "init_mha_chunk_metadata"):
+                forward_batch.attn_backend.init_mha_chunk_metadata(forward_batch)
+
+        forward_batch.mha_return_lse = has_extend_prefix
         # Do mha for extended part without prefix
         forward_batch.set_attn_attend_prefix_cache(False)
-        attn_output …
-        lse = torch.transpose(lse, 0, 1).contiguous()
+        attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)

         # Do mha attention with chunked prefix cache if there are any sequence with prefix
-        if …
-            …
-            if forward_batch.num_prefix_chunks is None:
-                forward_batch.prepare_chunked_prefix_cache_info(q.device)
-
+        if has_extend_prefix:
+            attn_output, lse = attn_output
             forward_batch.set_attn_attend_prefix_cache(True)
             attn_output = self._chunked_prefix_attn_mha(
                 q=q,
@@ -1866,10 +1828,11 @@ class DeepseekV2DecoderLayer(nn.Module):
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
             allow_reduce_scatter=True,
+            is_last_layer=(
+                is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
+            ),
         )

-        self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
-
     def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
         return is_nextn or (
             self.config.n_routed_experts is not None
@@ -1877,20 +1840,6 @@ class DeepseekV2DecoderLayer(nn.Module):
             and layer_id % self.config.moe_layer_freq == 0
         )

-    def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
-        """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
-
-        batch_size = (
-            forward_batch.input_ids.shape[0]
-            if hasattr(forward_batch, "input_ids")
-            else 0
-        )
-
-        if batch_size > 128:
-            return False
-
-        return self._fuse_allreduce_lookup_table.get(batch_size, False)
-
     def forward(
         self,
         positions: torch.Tensor,
@@ -1916,11 +1865,9 @@ class DeepseekV2DecoderLayer(nn.Module):
         )

         should_allreduce_fusion = (
-            self.…
-            …
-            is_dp_attention_enabled() and self.speculative_algorithm.is_eagle()
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
             )
-            and not self.is_nextn
         )

         # For DP with padding, reduce scatter can be used instead of all-reduce.
@@ -2011,26 +1958,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         )
         return output

-    def _build_fuse_allreduce_lookup_table(self):
-        static_conditions_met = (
-            self.layer_id != self.config.num_hidden_layers - 1
-            and get_tensor_model_parallel_world_size() > 1
-            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
-            and _is_sm100_supported
-            and _is_flashinfer_available
-        )
-
-        if not static_conditions_met:
-            return {}
-
-        lookup_table = {}
-        for batch_size in range(129):  # 0 to 128
-            is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
-            should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
-            lookup_table[batch_size] = should_fuse
-
-        return lookup_table
-

 class DeepseekV2Model(nn.Module):
     fall_back_to_pt_during_load = False
@@ -2045,26 +1972,52 @@ class DeepseekV2Model(nn.Module):
         self.padding_id = config.pad_token_id
         self.vocab_size = config.vocab_size
         self.first_k_dense_replace = config.first_k_dense_replace
+        self.pp_group = get_pp_group()
+
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                enable_tp=not is_dp_attention_enabled(),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()

-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            enable_tp=not is_dp_attention_enabled(),
-        )
         self.alt_stream = torch.cuda.Stream() if _is_cuda else None
-        self.layers = …
-        … (10 removed lines not shown in this view)
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: DeepseekV2DecoderLayer(
+                config=config,
+                layer_id=idx,
+                quant_config=quant_config,
+                prefix=prefix,
+                alt_stream=self.alt_stream,
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix=add_prefix("layers", prefix),
+            offloader_kwargs=dict(
+                submodule_accessor=lambda layer: (
+                    layer.mlp.experts
+                    if isinstance(layer.mlp, DeepseekV2MoE)
+                    else layer.mlp
+                ),
+                whitelist_param_names_creator=lambda module: (
+                    [
+                        "w13_weight",
+                        "w2_weight",
+                        "w13_blockscale_swizzled",
+                        "w2_blockscale_swizzled",
+                    ]
+                    if isinstance(module, FusedMoE)
+                    else []
+                ),
+            ),
         )
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)

     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
@@ -2075,8 +2028,9 @@ class DeepseekV2Model(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-        … (2 removed lines not shown in this view)
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        total_num_layers = self.end_layer - self.start_layer
         device = input_embeds.device if input_embeds is not None else input_ids.device
         zero_allocator = BumpAllocator(
             buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1),
@@ -2084,44 +2038,62 @@ class DeepseekV2Model(nn.Module):
             device=device,
         )

-        if …
-            …
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
         else:
-            …
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]

-        …
+        normal_start_layer = self.start_layer
+        normal_end_layer = self.end_layer
+        if forward_batch.can_run_tbo:
+            if (
+                self.first_k_dense_replace > normal_start_layer
+                and self.first_k_dense_replace < normal_end_layer
+            ):
+                normal_end_layer = self.first_k_dense_replace
+            elif self.first_k_dense_replace < normal_start_layer:
+                normal_end_layer = normal_start_layer = 0

-        …
-            self.first_k_dense_replace
-            if forward_batch.can_run_tbo
-            else total_num_layers
-        )
-        for i in range(normal_num_layers):
+        for i in range(normal_start_layer, normal_end_layer):
             with get_global_expert_distribution_recorder().with_current_layer(i):
                 layer = self.layers[i]
                 hidden_states, residual = layer(
                     positions, hidden_states, forward_batch, residual, zero_allocator
                 )

-        if …
+        if normal_end_layer != self.end_layer:
             hidden_states, residual = model_forward_maybe_tbo(
-                layers=self.layers[…
+                layers=self.layers[normal_end_layer : self.end_layer],
                 enable_tbo=True,
                 positions=positions,
                 forward_batch=forward_batch,
                 hidden_states=hidden_states,
                 residual=residual,
                 input_data_scatter_mode=self.layers[
-                    …
+                    normal_end_layer - 1
                 ].layer_scatter_modes.layer_output_mode,
                 zero_allocator=zero_allocator,
             )

-        if not …
-        … (4 removed lines not shown in this view)
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            if not forward_batch.forward_mode.is_idle():
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states


@@ -2148,6 +2120,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             "kv_a_proj_with_mqa",
         ]

+        self.pp_group = get_pp_group()
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
@@ -2217,13 +2190,27 @@ class DeepseekV2ForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(…
-        …
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+        hidden_states = self.model(
+            input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors
         )

+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
+
     def post_load_weights(self, is_nextn=False, weight_names=None):

         # Perform post-processing after loading weights
@@ -2231,7 +2218,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             layer_ids = [self.config.num_hidden_layers]
         else:
             if weight_names is None:
-                layer_ids = range(self.…
+                layer_ids = range(self.model.start_layer, self.model.end_layer)
             else:
                 layer_ids = set()
                 for name in weight_names:
@@ -2478,17 +2465,15 @@ class DeepseekV2ForCausalLM(nn.Module):

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = …
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
             num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
         )
         if self.quant_config and self.quant_config.get_name() == "w4afp8":
-            expert_params_mapping += (
-                …
-                num_experts=self.config.n_routed_experts
-            )
+            expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
+                num_experts=self.config.n_routed_experts
             )

         # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
@@ -2515,6 +2500,16 @@ class DeepseekV2ForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
         weight_names = []
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
             if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
                 name = name.replace(
                     "mlp.shared_experts",
@@ -2599,6 +2594,12 @@ class DeepseekV2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip loading embed_tokens if not first rank in pipeline parallelism
+                if ".embed_tokens." in name and not self.pp_group.is_first_rank:
+                    continue
+                # Skip loading norm if not last rank in pipeline parallelism
+                if ".norm." in name and not self.pp_group.is_last_rank:
+                    continue
                 if fuse_qkv_a_proj and (
                     "q_a_proj" in name or "kv_a_proj_with_mqa" in name
                 ):