sglang 0.5.0rc1-py3-none-any.whl → 0.5.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (203)
  1. sglang/bench_one_batch.py +0 -7
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +25 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -2
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +29 -4
  24. sglang/srt/entrypoints/http_server.py +76 -0
  25. sglang/srt/entrypoints/openai/protocol.py +4 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +23 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +10 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +14 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
  37. sglang/srt/layers/attention/triton_backend.py +109 -73
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +58 -10
  46. sglang/srt/layers/dp_attention.py +137 -27
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +16 -18
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  71. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  72. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  73. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  75. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  76. sglang/srt/layers/moe/router.py +15 -9
  77. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  78. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  79. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  80. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  81. sglang/srt/layers/moe/topk.py +167 -83
  82. sglang/srt/layers/moe/utils.py +159 -18
  83. sglang/srt/layers/multimodal.py +156 -40
  84. sglang/srt/layers/quantization/__init__.py +18 -46
  85. sglang/srt/layers/quantization/awq.py +22 -23
  86. sglang/srt/layers/quantization/base_config.py +2 -6
  87. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  88. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
  89. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  90. sglang/srt/layers/quantization/fp8.py +127 -119
  91. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  92. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  93. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  94. sglang/srt/layers/quantization/gptq.py +17 -21
  95. sglang/srt/layers/quantization/marlin_utils.py +26 -8
  96. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  97. sglang/srt/layers/quantization/modelopt_quant.py +217 -98
  98. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  99. sglang/srt/layers/quantization/mxfp4.py +222 -39
  100. sglang/srt/layers/quantization/quark/quark.py +390 -0
  101. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  102. sglang/srt/layers/quantization/unquant.py +34 -70
  103. sglang/srt/layers/quantization/utils.py +77 -2
  104. sglang/srt/layers/quantization/w4afp8.py +7 -8
  105. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  106. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  107. sglang/srt/layers/radix_attention.py +6 -0
  108. sglang/srt/layers/rotary_embedding.py +1 -0
  109. sglang/srt/layers/sampler.py +5 -2
  110. sglang/srt/lora/layers.py +6 -2
  111. sglang/srt/lora/lora_manager.py +21 -22
  112. sglang/srt/lora/lora_registry.py +3 -3
  113. sglang/srt/lora/mem_pool.py +26 -24
  114. sglang/srt/lora/utils.py +10 -12
  115. sglang/srt/managers/cache_controller.py +80 -19
  116. sglang/srt/managers/detokenizer_manager.py +10 -2
  117. sglang/srt/managers/io_struct.py +23 -0
  118. sglang/srt/managers/mm_utils.py +1 -1
  119. sglang/srt/managers/schedule_batch.py +22 -48
  120. sglang/srt/managers/scheduler.py +28 -20
  121. sglang/srt/managers/session_controller.py +1 -1
  122. sglang/srt/managers/template_manager.py +7 -5
  123. sglang/srt/managers/tokenizer_manager.py +88 -39
  124. sglang/srt/managers/tp_worker.py +1 -0
  125. sglang/srt/managers/utils.py +59 -1
  126. sglang/srt/mem_cache/allocator.py +10 -157
  127. sglang/srt/mem_cache/allocator_ascend.py +147 -0
  128. sglang/srt/mem_cache/chunk_cache.py +1 -1
  129. sglang/srt/mem_cache/hicache_storage.py +14 -4
  130. sglang/srt/mem_cache/memory_pool.py +3 -3
  131. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  132. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  133. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  134. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  135. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  136. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  137. sglang/srt/model_executor/cuda_graph_runner.py +33 -33
  138. sglang/srt/model_executor/forward_batch_info.py +11 -10
  139. sglang/srt/model_executor/model_runner.py +93 -78
  140. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  141. sglang/srt/model_loader/loader.py +24 -6
  142. sglang/srt/models/dbrx.py +12 -6
  143. sglang/srt/models/deepseek.py +2 -1
  144. sglang/srt/models/deepseek_nextn.py +5 -2
  145. sglang/srt/models/deepseek_v2.py +226 -223
  146. sglang/srt/models/ernie4.py +2 -2
  147. sglang/srt/models/glm4_moe.py +27 -65
  148. sglang/srt/models/glm4_moe_nextn.py +2 -1
  149. sglang/srt/models/glm4v.py +52 -1
  150. sglang/srt/models/glm4v_moe.py +8 -11
  151. sglang/srt/models/gpt_oss.py +41 -76
  152. sglang/srt/models/granitemoe.py +0 -1
  153. sglang/srt/models/grok.py +376 -48
  154. sglang/srt/models/interns1.py +12 -47
  155. sglang/srt/models/internvl.py +6 -51
  156. sglang/srt/models/llama.py +10 -2
  157. sglang/srt/models/llama4.py +18 -7
  158. sglang/srt/models/minicpm3.py +0 -1
  159. sglang/srt/models/mixtral.py +0 -2
  160. sglang/srt/models/nemotron_nas.py +435 -0
  161. sglang/srt/models/olmoe.py +0 -1
  162. sglang/srt/models/phi4mm.py +3 -21
  163. sglang/srt/models/qwen2.py +2 -2
  164. sglang/srt/models/qwen2_5_vl.py +2 -0
  165. sglang/srt/models/qwen2_moe.py +23 -23
  166. sglang/srt/models/qwen3.py +2 -2
  167. sglang/srt/models/qwen3_classification.py +84 -0
  168. sglang/srt/models/qwen3_moe.py +27 -43
  169. sglang/srt/models/step3_vl.py +8 -3
  170. sglang/srt/models/xverse_moe.py +11 -5
  171. sglang/srt/multimodal/processors/base_processor.py +3 -3
  172. sglang/srt/multimodal/processors/internvl.py +7 -2
  173. sglang/srt/multimodal/processors/llava.py +11 -7
  174. sglang/srt/offloader.py +433 -0
  175. sglang/srt/operations.py +22 -2
  176. sglang/srt/reasoning_parser.py +4 -3
  177. sglang/srt/sampling/sampling_batch_info.py +7 -4
  178. sglang/srt/server_args.py +264 -105
  179. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
  180. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  181. sglang/srt/speculative/eagle_utils.py +36 -13
  182. sglang/srt/speculative/eagle_worker.py +56 -3
  183. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  184. sglang/srt/two_batch_overlap.py +20 -19
  185. sglang/srt/utils.py +68 -70
  186. sglang/test/runners.py +8 -5
  187. sglang/test/test_block_fp8.py +5 -6
  188. sglang/test/test_block_fp8_ep.py +13 -19
  189. sglang/test/test_cutlass_moe.py +4 -6
  190. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  191. sglang/test/test_fp4_moe.py +4 -3
  192. sglang/test/test_marlin_moe.py +1 -1
  193. sglang/test/test_marlin_utils.py +1 -1
  194. sglang/test/test_utils.py +7 -0
  195. sglang/utils.py +0 -1
  196. sglang/version.py +1 -1
  197. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
  198. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
  199. sglang/srt/layers/quantization/fp4.py +0 -557
  200. sglang/srt/layers/quantization/scalar_type.py +0 -352
  201. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  202. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  203. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py

@@ -20,7 +20,7 @@ import concurrent.futures
  import logging
  import os
  from enum import IntEnum, auto
- from typing import Any, Dict, Iterable, Optional, Tuple
+ from typing import Any, Dict, Iterable, Optional, Tuple, Union

  import torch
  import torch.nn.functional as F
@@ -30,6 +30,7 @@ from transformers import PretrainedConfig

  from sglang.srt.distributed import (
  get_moe_expert_parallel_world_size,
+ get_pp_group,
  get_tensor_model_parallel_world_size,
  parallel_state,
  tensor_model_parallel_all_reduce,
@@ -50,7 +51,7 @@ from sglang.srt.layers.communicator import (
  from sglang.srt.layers.dp_attention import (
  get_attention_tp_rank,
  get_attention_tp_size,
- get_local_attention_dp_size,
+ is_dp_attention_enabled,
  )
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
@@ -60,9 +61,14 @@ from sglang.srt.layers.linear import (
  RowParallelLinear,
  )
  from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.moe import (
+ get_deepep_mode,
+ get_moe_a2a_backend,
+ should_use_flashinfer_cutlass_moe_fp4_allgather,
+ )
  from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
+ from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
  from sglang.srt.layers.moe.topk import TopK
- from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
  from sglang.srt.layers.quantization import deep_gemm_wrapper
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.quantization.fp8_kernel import (
@@ -82,13 +88,13 @@ from sglang.srt.layers.quantization.int8_utils import (
  )
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
- from sglang.srt.layers.utils import is_sm100_supported
+ from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
  from sglang.srt.layers.vocab_parallel_embedding import (
  ParallelLMHead,
  VocabParallelEmbedding,
  )
  from sglang.srt.managers.schedule_batch import global_server_args_dict
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
  from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.two_batch_overlap import (
  MaybeTboDeepEPDispatcher,
@@ -109,6 +115,7 @@ from sglang.srt.utils import (
  is_hip,
  is_non_idle_and_non_empty,
  log_info_on_rank0,
+ make_layers,
  use_intel_amx_backend,
  )

@@ -312,18 +319,7 @@ class DeepseekV2MoE(nn.Module):
  config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
  )

- self.topk = TopK(
- top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
- renormalize=config.norm_topk_prob,
- use_grouped_topk=True,
- num_expert_group=config.n_group,
- num_fused_shared_experts=self.num_fused_shared_experts,
- topk_group=config.topk_group,
- correction_bias=self.gate.e_score_correction_bias,
- routed_scaling_factor=self.routed_scaling_factor,
- )
-
- self.experts = get_moe_impl_class()(
+ self.experts = get_moe_impl_class(quant_config)(
  num_experts=config.n_routed_experts
  + self.num_fused_shared_experts
  + global_server_args_dict["ep_num_redundant_experts"],
@@ -335,30 +331,19 @@ class DeepseekV2MoE(nn.Module):
  quant_config=quant_config,
  routed_scaling_factor=self.routed_scaling_factor,
  prefix=add_prefix("experts", prefix),
- **(
- dict(deepep_mode=global_server_args_dict["deepep_mode"])
- if global_server_args_dict["moe_a2a_backend"].is_deepep()
- else {}
- ),
- # Additional args for FusedMoE
- **(
- dict(
- enable_flashinfer_cutlass_moe=True,
- )
- if global_server_args_dict["enable_flashinfer_cutlass_moe"]
- else {}
- ),
- **(
- dict(
- renormalize=config.norm_topk_prob,
- use_grouped_topk=True,
- num_expert_group=config.n_group,
- topk_group=config.topk_group,
- correction_bias=self.gate.e_score_correction_bias,
- )
- if should_use_flashinfer_trtllm_moe()
- else {}
- ),
+ )
+
+ self.topk = TopK(
+ top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+ renormalize=config.norm_topk_prob,
+ use_grouped_topk=True,
+ num_expert_group=config.n_group,
+ num_fused_shared_experts=self.num_fused_shared_experts,
+ topk_group=config.topk_group,
+ correction_bias=self.gate.e_score_correction_bias,
+ routed_scaling_factor=self.routed_scaling_factor,
+ apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
+ force_topk=quant_config is None,
  )

  self.shared_experts_is_int8 = False
@@ -366,7 +351,7 @@ class DeepseekV2MoE(nn.Module):
  self.shared_experts_weight_block_size = None
  if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
  intermediate_size = config.moe_intermediate_size * config.n_shared_experts
- # disable tp for shared experts when enable deepep moe
+ # disable tp for shared experts when enable deepep moe, or with fp4 allgather
  self.shared_experts = DeepseekV2MLP(
  hidden_size=config.hidden_size,
  intermediate_size=intermediate_size,
@@ -376,7 +361,8 @@ class DeepseekV2MoE(nn.Module):
  prefix=add_prefix("shared_experts", prefix),
  **(
  dict(tp_rank=0, tp_size=1)
- if global_server_args_dict["moe_a2a_backend"].is_deepep()
+ if get_moe_a2a_backend().is_deepep()
+ or should_use_flashinfer_cutlass_moe_fp4_allgather()
  else {}
  ),
  )
@@ -406,7 +392,7 @@ class DeepseekV2MoE(nn.Module):

  self.top_k = config.num_experts_per_tok

- if global_server_args_dict["moe_a2a_backend"].is_deepep():
+ if get_moe_a2a_backend().is_deepep():
  # TODO: we will support tp < ep in the future
  self.ep_size = get_moe_expert_parallel_world_size()
  self.num_experts = (
@@ -430,12 +416,12 @@ class DeepseekV2MoE(nn.Module):
  num_local_experts=config.n_routed_experts // self.tp_size,
  hidden_size=config.hidden_size,
  params_dtype=config.torch_dtype,
- deepep_mode=global_server_args_dict["deepep_mode"],
+ deepep_mode=get_deepep_mode(),
  async_finish=True,
  return_recv_hook=True,
  )

- self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
+ self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()

  def get_moe_weights(self):
  return [
@@ -456,14 +442,19 @@ class DeepseekV2MoE(nn.Module):
  if (
  self.alt_stream is not None
  and self.num_fused_shared_experts == 0
+ and hidden_states.shape[0] > 0
  and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
  ):
  return self.forward_normal_dual_stream(
- hidden_states, should_allreduce_fusion, use_reduce_scatter
+ hidden_states,
+ should_allreduce_fusion,
+ use_reduce_scatter,
  )
  else:
  return self.forward_normal(
- hidden_states, should_allreduce_fusion, use_reduce_scatter
+ hidden_states,
+ should_allreduce_fusion,
+ use_reduce_scatter,
  )
  else:
  return self.forward_deepep(hidden_states, forward_batch)
@@ -482,25 +473,24 @@ class DeepseekV2MoE(nn.Module):
  with torch.cuda.stream(self.alt_stream):
  # router_logits: (num_tokens, n_experts)
  router_logits = self.gate(hidden_states)
- kwargs = {"hidden_states": hidden_states}
-
- # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
- # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
- if should_use_flashinfer_trtllm_moe():
- kwargs["topk_output"] = (self.topk, router_logits)
- else:
- kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-
- final_hidden_states = self.experts(**kwargs)
+ topk_output = self.topk(hidden_states, router_logits)
+ final_hidden_states = self.experts(hidden_states, topk_output)
  if not _is_cuda:
  final_hidden_states *= self.routed_scaling_factor
+
  current_stream.wait_stream(self.alt_stream)
  with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
  final_hidden_states_out = torch.empty_like(final_hidden_states)
+
  torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
  final_hidden_states = final_hidden_states_out
  sm.tag(final_hidden_states)
- if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
+ if (
+ self.tp_size > 1
+ and not should_allreduce_fusion
+ and not use_reduce_scatter
+ and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+ ):
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
  return final_hidden_states

@@ -515,19 +505,16 @@ class DeepseekV2MoE(nn.Module):
  ):
  return self.forward_cpu(hidden_states, should_allreduce_fusion)

- shared_output = self._forward_shared_experts(hidden_states)
- # router_logits: (num_tokens, n_experts)
- router_logits = self.gate(hidden_states)
- kwargs = {"hidden_states": hidden_states}
-
- # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
- # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
- if should_use_flashinfer_trtllm_moe():
- kwargs["topk_output"] = (self.topk, router_logits)
+ if hidden_states.shape[0] > 0:
+ shared_output = self._forward_shared_experts(hidden_states)
+ # router_logits: (num_tokens, n_experts)
+ router_logits = self.gate(hidden_states)
+ topk_output = self.topk(hidden_states, router_logits)
  else:
- kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+ shared_output = None
+ topk_output = self.topk.empty_topk_output(hidden_states.device)

- final_hidden_states = self.experts(**kwargs)
+ final_hidden_states = self.experts(hidden_states, topk_output)
  if not _is_cuda and not _use_aiter:
  # fused in biased_grouped_topk so we can skip here
  final_hidden_states *= self.routed_scaling_factor
@@ -537,7 +524,12 @@ class DeepseekV2MoE(nn.Module):
  torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
  final_hidden_states = final_hidden_states_out
  sm.tag(final_hidden_states)
- if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
+ if (
+ self.tp_size > 1
+ and not should_allreduce_fusion
+ and not use_reduce_scatter
+ and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+ ):
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
  return final_hidden_states

@@ -616,11 +608,8 @@ class DeepseekV2MoE(nn.Module):
  ),
  )
  else:
- topk_idx = torch.full(
- (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
- )
- topk_weights = torch.empty(
- (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+ topk_weights, topk_idx, _ = self.topk.empty_topk_output(
+ hidden_states.device
  )

  final_hidden_states = self.experts(
@@ -1006,29 +995,33 @@ class DeepseekV2AttentionMLA(nn.Module):

  if attention_backend == "ascend":
  return AttnForwardMethod.MLA
- elif attention_backend == "flashinfer":
+ elif (
+ attention_backend == "flashinfer"
+ or attention_backend == "fa3"
+ or attention_backend == "flashmla"
+ or attention_backend == "trtllm_mla"
+ or attention_backend == "cutlass_mla"
+ ):
+ # Use MHA with chunked KV cache when prefilling on long sequences.
+ sum_extend_prefix_lens = (
+ sum(forward_batch.extend_prefix_lens_cpu)
+ if forward_batch.extend_prefix_lens_cpu is not None
+ else 0
+ )
  # Flashinfer MLA: Do not absorb when enabling ragged prefill
+ disable_ragged = (
+ attention_backend == "flashinfer" or attention_backend == "flashmla"
+ ) and self.flashinfer_mla_disable_ragged
  if (
- not self.flashinfer_mla_disable_ragged
+ not disable_ragged
  and forward_batch.forward_mode.is_extend()
  and not forward_batch.forward_mode.is_target_verify()
  and not forward_batch.forward_mode.is_draft_extend()
- and sum(forward_batch.extend_prefix_lens_cpu) == 0
- ):
- return AttnForwardMethod.MHA
- else:
- return _dispatch_mla_subtype()
- elif attention_backend == "fa3":
- # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
- if forward_batch.extend_prefix_lens_cpu is not None:
- sum_extend_prefix_lens = sum(forward_batch.extend_prefix_lens_cpu)
- if (
- forward_batch.forward_mode.is_extend()
- and not self.disable_chunked_prefix_cache
- and not forward_batch.forward_mode.is_target_verify()
- and not forward_batch.forward_mode.is_draft_extend()
  and (
- sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
+ (
+ sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
+ and not self.disable_chunked_prefix_cache
+ )
  or sum_extend_prefix_lens == 0
  )
  ):
@@ -1696,7 +1689,6 @@ class DeepseekV2AttentionMLA(nn.Module):
  k[..., self.qk_nope_head_dim :] = k_pe

  output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
- lse = torch.transpose(lse, 0, 1).contiguous()
  tmp_output = torch.empty_like(accum_output)
  tmp_lse = torch.empty_like(accum_lse)
  merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
@@ -1718,55 +1710,26 @@ class DeepseekV2AttentionMLA(nn.Module):
  # will be helpful for understanding the purpose of this function.

  # First do normal mha forward to get output for extended part
- if self.q_lora_rank is not None:
- q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
- [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
- )
- q = self.q_a_layernorm(q)
- q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
- else:
- q = self.q_proj(hidden_states)[0].view(
- -1, self.num_local_heads, self.qk_head_dim
- )
- latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
- _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
- kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
- latent_cache = latent_cache.unsqueeze(1)
- kv_a = self.kv_a_layernorm(kv_a)
- kv = self.kv_b_proj(kv_a)[0]
- kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
- k_nope = kv[..., : self.qk_nope_head_dim]
- v = kv[..., self.qk_nope_head_dim :]
- k_pe = latent_cache[:, :, self.kv_lora_rank :]
-
- q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
- q[..., self.qk_nope_head_dim :] = q_pe
- k = torch.empty_like(q)
- k[..., : self.qk_nope_head_dim] = k_nope
- k[..., self.qk_nope_head_dim :] = k_pe
-
- latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
- latent_cache[:, :, self.kv_lora_rank :] = k_pe
-
- # Save latent cache
- forward_batch.token_to_kv_pool.set_kv_buffer(
- self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+ return self.forward_normal_prepare(
+ positions, hidden_states, forward_batch, zero_allocator
  )

- return q, k, v, forward_batch
-
  def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
+ has_extend_prefix = any(forward_batch.extend_prefix_lens_cpu)
+ # Only initialize the info once
+ if has_extend_prefix and forward_batch.num_prefix_chunks is None:
+ forward_batch.prepare_chunked_prefix_cache_info(q.device)
+ if hasattr(forward_batch.attn_backend, "init_mha_chunk_metadata"):
+ forward_batch.attn_backend.init_mha_chunk_metadata(forward_batch)
+
+ forward_batch.mha_return_lse = has_extend_prefix
  # Do mha for extended part without prefix
  forward_batch.set_attn_attend_prefix_cache(False)
- attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
- lse = torch.transpose(lse, 0, 1).contiguous()
+ attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)

  # Do mha attention with chunked prefix cache if there are any sequence with prefix
- if any(forward_batch.extend_prefix_lens_cpu):
- # Only initialize the info once
- if forward_batch.num_prefix_chunks is None:
- forward_batch.prepare_chunked_prefix_cache_info(q.device)
-
+ if has_extend_prefix:
+ attn_output, lse = attn_output
  forward_batch.set_attn_attend_prefix_cache(True)
  attn_output = self._chunked_prefix_attn_mha(
  q=q,
@@ -1797,7 +1760,6 @@ class DeepseekV2DecoderLayer(nn.Module):
  rope_theta = getattr(config, "rope_theta", 10000)
  rope_scaling = getattr(config, "rope_scaling", None)
  max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
- self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
  self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
  self.layer_id = layer_id
  self.is_nextn = is_nextn
@@ -1866,10 +1828,11 @@ class DeepseekV2DecoderLayer(nn.Module):
  input_layernorm=self.input_layernorm,
  post_attention_layernorm=self.post_attention_layernorm,
  allow_reduce_scatter=True,
+ is_last_layer=(
+ is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
+ ),
  )

- self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
-
  def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
  return is_nextn or (
  self.config.n_routed_experts is not None
@@ -1877,20 +1840,6 @@ class DeepseekV2DecoderLayer(nn.Module):
  and layer_id % self.config.moe_layer_freq == 0
  )

- def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
- """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
-
- batch_size = (
- forward_batch.input_ids.shape[0]
- if hasattr(forward_batch, "input_ids")
- else 0
- )
-
- if batch_size > 128:
- return False
-
- return self._fuse_allreduce_lookup_table.get(batch_size, False)
-
  def forward(
  self,
  positions: torch.Tensor,
@@ -1916,9 +1865,9 @@ class DeepseekV2DecoderLayer(nn.Module):
  )

  should_allreduce_fusion = (
- self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
- and not (self.enable_dp_attention and self.speculative_algorithm.is_eagle())
- and not self.is_nextn
+ self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+ forward_batch
+ )
  )

  # For DP with padding, reduce scatter can be used instead of all-reduce.
@@ -2009,26 +1958,6 @@ class DeepseekV2DecoderLayer(nn.Module):
  )
  return output

- def _build_fuse_allreduce_lookup_table(self):
- static_conditions_met = (
- self.layer_id != self.config.num_hidden_layers - 1
- and get_tensor_model_parallel_world_size() > 1
- and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
- and _is_sm100_supported
- and _is_flashinfer_available
- )
-
- if not static_conditions_met:
- return {}
-
- lookup_table = {}
- for batch_size in range(129): # 0 to 128
- is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
- should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
- lookup_table[batch_size] = should_fuse
-
- return lookup_table
-

  class DeepseekV2Model(nn.Module):
  fall_back_to_pt_during_load = False
@@ -2043,26 +1972,52 @@ class DeepseekV2Model(nn.Module):
  self.padding_id = config.pad_token_id
  self.vocab_size = config.vocab_size
  self.first_k_dense_replace = config.first_k_dense_replace
+ self.pp_group = get_pp_group()
+
+ if self.pp_group.is_first_rank:
+ self.embed_tokens = VocabParallelEmbedding(
+ config.vocab_size,
+ config.hidden_size,
+ enable_tp=not is_dp_attention_enabled(),
+ )
+ else:
+ self.embed_tokens = PPMissingLayer()

- self.embed_tokens = VocabParallelEmbedding(
- config.vocab_size,
- config.hidden_size,
- enable_tp=not global_server_args_dict["enable_dp_attention"],
- )
  self.alt_stream = torch.cuda.Stream() if _is_cuda else None
- self.layers = nn.ModuleList(
- [
- DeepseekV2DecoderLayer(
- config,
- layer_id,
- quant_config=quant_config,
- prefix=add_prefix(f"layers.{layer_id}", prefix),
- alt_stream=self.alt_stream,
- )
- for layer_id in range(config.num_hidden_layers)
- ]
+ self.layers, self.start_layer, self.end_layer = make_layers(
+ config.num_hidden_layers,
+ lambda idx, prefix: DeepseekV2DecoderLayer(
+ config=config,
+ layer_id=idx,
+ quant_config=quant_config,
+ prefix=prefix,
+ alt_stream=self.alt_stream,
+ ),
+ pp_rank=self.pp_group.rank_in_group,
+ pp_size=self.pp_group.world_size,
+ prefix=add_prefix("layers", prefix),
+ offloader_kwargs=dict(
+ submodule_accessor=lambda layer: (
+ layer.mlp.experts
+ if isinstance(layer.mlp, DeepseekV2MoE)
+ else layer.mlp
+ ),
+ whitelist_param_names_creator=lambda module: (
+ [
+ "w13_weight",
+ "w2_weight",
+ "w13_blockscale_swizzled",
+ "w2_blockscale_swizzled",
+ ]
+ if isinstance(module, FusedMoE)
+ else []
+ ),
+ ),
  )
- self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ if self.pp_group.is_last_rank:
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ else:
+ self.norm = PPMissingLayer(return_tuple=True)

  def get_input_embeddings(self) -> torch.Tensor:
  return self.embed_tokens
@@ -2073,8 +2028,9 @@ class DeepseekV2Model(nn.Module):
  positions: torch.Tensor,
  forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
- ) -> torch.Tensor:
- total_num_layers = len(self.layers)
+ pp_proxy_tensors: Optional[PPProxyTensors] = None,
+ ) -> Union[torch.Tensor, PPProxyTensors]:
+ total_num_layers = self.end_layer - self.start_layer
  device = input_embeds.device if input_embeds is not None else input_ids.device
  zero_allocator = BumpAllocator(
  buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1),
@@ -2082,44 +2038,62 @@ class DeepseekV2Model(nn.Module):
  device=device,
  )

- if input_embeds is None:
- hidden_states = self.embed_tokens(input_ids)
+ if self.pp_group.is_first_rank:
+ if input_embeds is None:
+ hidden_states = self.embed_tokens(input_ids)
+ else:
+ hidden_states = input_embeds
+ residual = None
  else:
- hidden_states = input_embeds
+ assert pp_proxy_tensors is not None
+ hidden_states = pp_proxy_tensors["hidden_states"]
+ residual = pp_proxy_tensors["residual"]

- residual = None
+ normal_start_layer = self.start_layer
+ normal_end_layer = self.end_layer
+ if forward_batch.can_run_tbo:
+ if (
+ self.first_k_dense_replace > normal_start_layer
+ and self.first_k_dense_replace < normal_end_layer
+ ):
+ normal_end_layer = self.first_k_dense_replace
+ elif self.first_k_dense_replace < normal_start_layer:
+ normal_end_layer = normal_start_layer = 0

- normal_num_layers = (
- self.first_k_dense_replace
- if forward_batch.can_run_tbo
- else total_num_layers
- )
- for i in range(normal_num_layers):
+ for i in range(normal_start_layer, normal_end_layer):
  with get_global_expert_distribution_recorder().with_current_layer(i):
  layer = self.layers[i]
  hidden_states, residual = layer(
  positions, hidden_states, forward_batch, residual, zero_allocator
  )

- if normal_num_layers != total_num_layers:
+ if normal_end_layer != self.end_layer:
  hidden_states, residual = model_forward_maybe_tbo(
- layers=self.layers[normal_num_layers:],
+ layers=self.layers[normal_end_layer : self.end_layer],
  enable_tbo=True,
  positions=positions,
  forward_batch=forward_batch,
  hidden_states=hidden_states,
  residual=residual,
  input_data_scatter_mode=self.layers[
- normal_num_layers - 1
+ normal_end_layer - 1
  ].layer_scatter_modes.layer_output_mode,
  zero_allocator=zero_allocator,
  )

- if not forward_batch.forward_mode.is_idle():
- if residual is None:
- hidden_states = self.norm(hidden_states)
- else:
- hidden_states, _ = self.norm(hidden_states, residual)
+ if not self.pp_group.is_last_rank:
+ return PPProxyTensors(
+ {
+ "hidden_states": hidden_states,
+ "residual": residual,
+ }
+ )
+ else:
+ if not forward_batch.forward_mode.is_idle():
+ if residual is None:
+ hidden_states = self.norm(hidden_states)
+ else:
+ hidden_states, _ = self.norm(hidden_states, residual)
  return hidden_states


@@ -2146,6 +2120,7 @@ class DeepseekV2ForCausalLM(nn.Module):
  "kv_a_proj_with_mqa",
  ]

+ self.pp_group = get_pp_group()
  self.config = config
  self.tp_size = get_tensor_model_parallel_world_size()
  self.quant_config = quant_config
@@ -2215,13 +2190,27 @@ class DeepseekV2ForCausalLM(nn.Module):
  positions: torch.Tensor,
  forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
+ pp_proxy_tensors: Optional[PPProxyTensors] = None,
  ) -> torch.Tensor:
- hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-
- return self.logits_processor(
- input_ids, hidden_states, self.lm_head, forward_batch
+ hidden_states = self.model(
+ input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors
  )

+ if self.pp_group.is_last_rank:
+ return self.logits_processor(
+ input_ids, hidden_states, self.lm_head, forward_batch
+ )
+ else:
+ return hidden_states
+
+ @property
+ def start_layer(self):
+ return self.model.start_layer
+
+ @property
+ def end_layer(self):
+ return self.model.end_layer
+
  def post_load_weights(self, is_nextn=False, weight_names=None):

  # Perform post-processing after loading weights
@@ -2229,7 +2218,7 @@ class DeepseekV2ForCausalLM(nn.Module):
  layer_ids = [self.config.num_hidden_layers]
  else:
  if weight_names is None:
- layer_ids = range(self.config.num_hidden_layers)
+ layer_ids = range(self.model.start_layer, self.model.end_layer)
  else:
  layer_ids = set()
  for name in weight_names:
@@ -2476,17 +2465,15 @@ class DeepseekV2ForCausalLM(nn.Module):

  # Params for weights, fp8 weight scales, fp8 activation scales
  # (param_name, weight_name, expert_id, shard_id)
- expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+ expert_params_mapping = FusedMoE.make_expert_params_mapping(
  ckpt_gate_proj_name="gate_proj",
  ckpt_down_proj_name="down_proj",
  ckpt_up_proj_name="up_proj",
  num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
  )
  if self.quant_config and self.quant_config.get_name() == "w4afp8":
- expert_params_mapping += (
- get_moe_impl_class().make_expert_input_scale_params_mapping(
- num_experts=self.config.n_routed_experts
- )
+ expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
+ num_experts=self.config.n_routed_experts
  )

  # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
@@ -2513,6 +2500,16 @@ class DeepseekV2ForCausalLM(nn.Module):
  params_dict = dict(self.named_parameters())
  weight_names = []
  for name, loaded_weight in weights:
+ layer_id = get_layer_id(name)
+ if (
+ layer_id is not None
+ and hasattr(self.model, "start_layer")
+ and (
+ layer_id < self.model.start_layer
+ or layer_id >= self.model.end_layer
+ )
+ ):
+ continue
  if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
  name = name.replace(
  "mlp.shared_experts",
@@ -2597,6 +2594,12 @@ class DeepseekV2ForCausalLM(nn.Module):
  # Skip loading extra bias for GPTQ models.
  if name.endswith(".bias") and name not in params_dict:
  continue
+ # Skip loading embed_tokens if not first rank in pipeline parallelism
+ if ".embed_tokens." in name and not self.pp_group.is_first_rank:
+ continue
+ # Skip loading norm if not last rank in pipeline parallelism
+ if ".norm." in name and not self.pp_group.is_last_rank:
+ continue
  if fuse_qkv_a_proj and (
  "q_a_proj" in name or "kv_a_proj_with_mqa" in name
  ):