sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +376 -48
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
@@ -20,7 +20,7 @@ import concurrent.futures
 import logging
 import os
 from enum import IntEnum, auto
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -30,6 +30,7 @@ from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
+    get_pp_group,
     get_tensor_model_parallel_world_size,
     parallel_state,
     tensor_model_parallel_all_reduce,
@@ -50,7 +51,6 @@ from sglang.srt.layers.communicator import (
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
     is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
@@ -61,9 +61,14 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import (
+    get_deepep_mode,
+    get_moe_a2a_backend,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+)
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
-from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
@@ -83,13 +88,13 @@ from sglang.srt.layers.quantization.int8_utils import (
 )
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
-from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.two_batch_overlap import (
     MaybeTboDeepEPDispatcher,
@@ -110,6 +115,7 @@ from sglang.srt.utils import (
     is_hip,
     is_non_idle_and_non_empty,
     log_info_on_rank0,
+    make_layers,
     use_intel_amx_backend,
 )
 
@@ -313,18 +319,7 @@ class DeepseekV2MoE(nn.Module):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )
 
-        self.topk = TopK(
-            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
-            renormalize=config.norm_topk_prob,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            topk_group=config.topk_group,
-            correction_bias=self.gate.e_score_correction_bias,
-            routed_scaling_factor=self.routed_scaling_factor,
-        )
-
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
@@ -336,30 +331,19 @@ class DeepseekV2MoE(nn.Module):
             quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
             prefix=add_prefix("experts", prefix),
-            **(
-                dict(deepep_mode=global_server_args_dict["deepep_mode"])
-                if global_server_args_dict["moe_a2a_backend"].is_deepep()
-                else {}
-            ),
-            # Additional args for FusedMoE
-            **(
-                dict(
-                    enable_flashinfer_cutlass_moe=True,
-                )
-                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
-                else {}
-            ),
-            **(
-                dict(
-                    renormalize=config.norm_topk_prob,
-                    use_grouped_topk=True,
-                    num_expert_group=config.n_group,
-                    topk_group=config.topk_group,
-                    correction_bias=self.gate.e_score_correction_bias,
-                )
-                if should_use_flashinfer_trtllm_moe()
-                else {}
-            ),
+        )
+
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            renormalize=config.norm_topk_prob,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
+            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
+            force_topk=quant_config is None,
         )
 
         self.shared_experts_is_int8 = False
@@ -367,7 +351,7 @@ class DeepseekV2MoE(nn.Module):
         self.shared_experts_weight_block_size = None
         if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
-            # disable tp for shared experts when enable deepep moe
+            # disable tp for shared experts when enable deepep moe, or with fp4 allgather
             self.shared_experts = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=intermediate_size,
@@ -377,7 +361,8 @@ class DeepseekV2MoE(nn.Module):
                 prefix=add_prefix("shared_experts", prefix),
                 **(
                     dict(tp_rank=0, tp_size=1)
-                    if global_server_args_dict["moe_a2a_backend"].is_deepep()
+                    if get_moe_a2a_backend().is_deepep()
+                    or should_use_flashinfer_cutlass_moe_fp4_allgather()
                     else {}
                 ),
             )
@@ -407,7 +392,7 @@ class DeepseekV2MoE(nn.Module):
 
         self.top_k = config.num_experts_per_tok
 
-        if global_server_args_dict["moe_a2a_backend"].is_deepep():
+        if get_moe_a2a_backend().is_deepep():
             # TODO: we will support tp < ep in the future
             self.ep_size = get_moe_expert_parallel_world_size()
             self.num_experts = (
@@ -431,12 +416,12 @@ class DeepseekV2MoE(nn.Module):
                 num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
-                deepep_mode=global_server_args_dict["deepep_mode"],
+                deepep_mode=get_deepep_mode(),
                 async_finish=True,
                 return_recv_hook=True,
             )
 
-        self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
+        self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()
 
     def get_moe_weights(self):
         return [
@@ -457,14 +442,19 @@ class DeepseekV2MoE(nn.Module):
             if (
                 self.alt_stream is not None
                 and self.num_fused_shared_experts == 0
+                and hidden_states.shape[0] > 0
                 and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
             ):
                 return self.forward_normal_dual_stream(
-                    hidden_states, should_allreduce_fusion, use_reduce_scatter
+                    hidden_states,
+                    should_allreduce_fusion,
+                    use_reduce_scatter,
                 )
             else:
                 return self.forward_normal(
-                    hidden_states, should_allreduce_fusion, use_reduce_scatter
+                    hidden_states,
+                    should_allreduce_fusion,
+                    use_reduce_scatter,
                 )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
@@ -483,25 +473,24 @@ class DeepseekV2MoE(nn.Module):
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            kwargs = {"hidden_states": hidden_states}
-
-            # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
-            # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
-            if should_use_flashinfer_trtllm_moe():
-                kwargs["topk_output"] = (self.topk, router_logits)
-            else:
-                kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-
-            final_hidden_states = self.experts(**kwargs)
+            topk_output = self.topk(hidden_states, router_logits)
+            final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
+
         current_stream.wait_stream(self.alt_stream)
         with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
             final_hidden_states_out = torch.empty_like(final_hidden_states)
+
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
-        if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
 
@@ -516,19 +505,16 @@ class DeepseekV2MoE(nn.Module):
         ):
             return self.forward_cpu(hidden_states, should_allreduce_fusion)
 
-        shared_output = self._forward_shared_experts(hidden_states)
-        # router_logits: (num_tokens, n_experts)
-        router_logits = self.gate(hidden_states)
-        kwargs = {"hidden_states": hidden_states}
-
-        # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
-        # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
-        if should_use_flashinfer_trtllm_moe():
-            kwargs["topk_output"] = (self.topk, router_logits)
+        if hidden_states.shape[0] > 0:
+            shared_output = self._forward_shared_experts(hidden_states)
+            # router_logits: (num_tokens, n_experts)
+            router_logits = self.gate(hidden_states)
+            topk_output = self.topk(hidden_states, router_logits)
         else:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+            shared_output = None
+            topk_output = self.topk.empty_topk_output(hidden_states.device)
 
-        final_hidden_states = self.experts(**kwargs)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
@@ -538,7 +524,12 @@ class DeepseekV2MoE(nn.Module):
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
-        if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
 
@@ -617,11 +608,8 @@ class DeepseekV2MoE(nn.Module):
                 ),
             )
         else:
-            topk_idx = torch.full(
-                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
-            )
-            topk_weights = torch.empty(
-                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
+                hidden_states.device
             )
 
         final_hidden_states = self.experts(
@@ -1007,29 +995,33 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         if attention_backend == "ascend":
             return AttnForwardMethod.MLA
-        elif attention_backend == "flashinfer":
+        elif (
+            attention_backend == "flashinfer"
+            or attention_backend == "fa3"
+            or attention_backend == "flashmla"
+            or attention_backend == "trtllm_mla"
+            or attention_backend == "cutlass_mla"
+        ):
+            # Use MHA with chunked KV cache when prefilling on long sequences.
+            sum_extend_prefix_lens = (
+                sum(forward_batch.extend_prefix_lens_cpu)
+                if forward_batch.extend_prefix_lens_cpu is not None
+                else 0
+            )
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
+            disable_ragged = (
+                attention_backend == "flashinfer" or attention_backend == "flashmla"
+            ) and self.flashinfer_mla_disable_ragged
             if (
-                not self.flashinfer_mla_disable_ragged
+                not disable_ragged
                 and forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
-                and sum(forward_batch.extend_prefix_lens_cpu) == 0
-            ):
-                return AttnForwardMethod.MHA
-            else:
-                return _dispatch_mla_subtype()
-        elif attention_backend == "fa3":
-            # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
-            if forward_batch.extend_prefix_lens_cpu is not None:
-                sum_extend_prefix_lens = sum(forward_batch.extend_prefix_lens_cpu)
-            if (
-                forward_batch.forward_mode.is_extend()
-                and not self.disable_chunked_prefix_cache
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
                 and (
-                    sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
+                    (
+                        sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
+                        and not self.disable_chunked_prefix_cache
+                    )
                     or sum_extend_prefix_lens == 0
                 )
             ):
@@ -1697,7 +1689,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             k[..., self.qk_nope_head_dim :] = k_pe
 
             output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
-            lse = torch.transpose(lse, 0, 1).contiguous()
             tmp_output = torch.empty_like(accum_output)
             tmp_lse = torch.empty_like(accum_lse)
             merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
@@ -1719,55 +1710,26 @@ class DeepseekV2AttentionMLA(nn.Module):
         # will be helpful for understanding the purpose of this function.
 
         # First do normal mha forward to get output for extended part
-        if self.q_lora_rank is not None:
-            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
-                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
-            )
-            q = self.q_a_layernorm(q)
-            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
-        else:
-            q = self.q_proj(hidden_states)[0].view(
-                -1, self.num_local_heads, self.qk_head_dim
-            )
-            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
-        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        latent_cache = latent_cache.unsqueeze(1)
-        kv_a = self.kv_a_layernorm(kv_a)
-        kv = self.kv_b_proj(kv_a)[0]
-        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-        k_nope = kv[..., : self.qk_nope_head_dim]
-        v = kv[..., self.qk_nope_head_dim :]
-        k_pe = latent_cache[:, :, self.kv_lora_rank :]
-
-        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
-        q[..., self.qk_nope_head_dim :] = q_pe
-        k = torch.empty_like(q)
-        k[..., : self.qk_nope_head_dim] = k_nope
-        k[..., self.qk_nope_head_dim :] = k_pe
-
-        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
-        latent_cache[:, :, self.kv_lora_rank :] = k_pe
-
-        # Save latent cache
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+        return self.forward_normal_prepare(
+            positions, hidden_states, forward_batch, zero_allocator
         )
 
-        return q, k, v, forward_batch
-
     def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
+        has_extend_prefix = any(forward_batch.extend_prefix_lens_cpu)
+        # Only initialize the info once
+        if has_extend_prefix and forward_batch.num_prefix_chunks is None:
+            forward_batch.prepare_chunked_prefix_cache_info(q.device)
+            if hasattr(forward_batch.attn_backend, "init_mha_chunk_metadata"):
+                forward_batch.attn_backend.init_mha_chunk_metadata(forward_batch)
+
+        forward_batch.mha_return_lse = has_extend_prefix
         # Do mha for extended part without prefix
         forward_batch.set_attn_attend_prefix_cache(False)
-        attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
-        lse = torch.transpose(lse, 0, 1).contiguous()
+        attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
 
         # Do mha attention with chunked prefix cache if there are any sequence with prefix
-        if any(forward_batch.extend_prefix_lens_cpu):
-            # Only initialize the info once
-            if forward_batch.num_prefix_chunks is None:
-                forward_batch.prepare_chunked_prefix_cache_info(q.device)
-
+        if has_extend_prefix:
+            attn_output, lse = attn_output
             forward_batch.set_attn_attend_prefix_cache(True)
             attn_output = self._chunked_prefix_attn_mha(
                 q=q,
@@ -1866,10 +1828,11 @@ class DeepseekV2DecoderLayer(nn.Module):
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
             allow_reduce_scatter=True,
+            is_last_layer=(
+                is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
+            ),
         )
 
-        self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
-
     def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
         return is_nextn or (
             self.config.n_routed_experts is not None
@@ -1877,20 +1840,6 @@ class DeepseekV2DecoderLayer(nn.Module):
             and layer_id % self.config.moe_layer_freq == 0
         )
 
-    def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
-        """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
-
-        batch_size = (
-            forward_batch.input_ids.shape[0]
-            if hasattr(forward_batch, "input_ids")
-            else 0
-        )
-
-        if batch_size > 128:
-            return False
-
-        return self._fuse_allreduce_lookup_table.get(batch_size, False)
-
     def forward(
         self,
         positions: torch.Tensor,
@@ -1916,11 +1865,9 @@ class DeepseekV2DecoderLayer(nn.Module):
         )
 
         should_allreduce_fusion = (
-            self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
-            and not (
-                is_dp_attention_enabled() and self.speculative_algorithm.is_eagle()
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
             )
-            and not self.is_nextn
         )
 
         # For DP with padding, reduce scatter can be used instead of all-reduce.
@@ -2011,26 +1958,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         )
         return output
 
-    def _build_fuse_allreduce_lookup_table(self):
-        static_conditions_met = (
-            self.layer_id != self.config.num_hidden_layers - 1
-            and get_tensor_model_parallel_world_size() > 1
-            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
-            and _is_sm100_supported
-            and _is_flashinfer_available
-        )
-
-        if not static_conditions_met:
-            return {}
-
-        lookup_table = {}
-        for batch_size in range(129):  # 0 to 128
-            is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
-            should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
-            lookup_table[batch_size] = should_fuse
-
-        return lookup_table
-
 
 class DeepseekV2Model(nn.Module):
     fall_back_to_pt_during_load = False
@@ -2045,26 +1972,52 @@ class DeepseekV2Model(nn.Module):
         self.padding_id = config.pad_token_id
         self.vocab_size = config.vocab_size
         self.first_k_dense_replace = config.first_k_dense_replace
+        self.pp_group = get_pp_group()
+
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                enable_tp=not is_dp_attention_enabled(),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
 
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            enable_tp=not is_dp_attention_enabled(),
-        )
         self.alt_stream = torch.cuda.Stream() if _is_cuda else None
-        self.layers = nn.ModuleList(
-            [
-                DeepseekV2DecoderLayer(
-                    config,
-                    layer_id,
-                    quant_config=quant_config,
-                    prefix=add_prefix(f"layers.{layer_id}", prefix),
-                    alt_stream=self.alt_stream,
-                )
-                for layer_id in range(config.num_hidden_layers)
-            ]
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: DeepseekV2DecoderLayer(
+                config=config,
+                layer_id=idx,
+                quant_config=quant_config,
+                prefix=prefix,
+                alt_stream=self.alt_stream,
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix=add_prefix("layers", prefix),
+            offloader_kwargs=dict(
+                submodule_accessor=lambda layer: (
+                    layer.mlp.experts
+                    if isinstance(layer.mlp, DeepseekV2MoE)
+                    else layer.mlp
+                ),
+                whitelist_param_names_creator=lambda module: (
+                    [
+                        "w13_weight",
+                        "w2_weight",
+                        "w13_blockscale_swizzled",
+                        "w2_blockscale_swizzled",
+                    ]
+                    if isinstance(module, FusedMoE)
+                    else []
+                ),
+            ),
         )
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
 
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
@@ -2075,8 +2028,9 @@ class DeepseekV2Model(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
-        total_num_layers = len(self.layers)
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        total_num_layers = self.end_layer - self.start_layer
         device = input_embeds.device if input_embeds is not None else input_ids.device
         zero_allocator = BumpAllocator(
             buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1),
@@ -2084,44 +2038,62 @@ class DeepseekV2Model(nn.Module):
             device=device,
         )
 
-        if input_embeds is None:
-            hidden_states = self.embed_tokens(input_ids)
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
         else:
-            hidden_states = input_embeds
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
 
-        residual = None
+        normal_start_layer = self.start_layer
+        normal_end_layer = self.end_layer
+        if forward_batch.can_run_tbo:
+            if (
+                self.first_k_dense_replace > normal_start_layer
+                and self.first_k_dense_replace < normal_end_layer
+            ):
+                normal_end_layer = self.first_k_dense_replace
+            elif self.first_k_dense_replace < normal_start_layer:
+                normal_end_layer = normal_start_layer = 0
 
-        normal_num_layers = (
-            self.first_k_dense_replace
-            if forward_batch.can_run_tbo
-            else total_num_layers
-        )
-        for i in range(normal_num_layers):
+        for i in range(normal_start_layer, normal_end_layer):
             with get_global_expert_distribution_recorder().with_current_layer(i):
                 layer = self.layers[i]
                 hidden_states, residual = layer(
                     positions, hidden_states, forward_batch, residual, zero_allocator
                 )
 
-        if normal_num_layers != total_num_layers:
+        if normal_end_layer != self.end_layer:
             hidden_states, residual = model_forward_maybe_tbo(
-                layers=self.layers[normal_num_layers:],
+                layers=self.layers[normal_end_layer : self.end_layer],
                 enable_tbo=True,
                 positions=positions,
                 forward_batch=forward_batch,
                 hidden_states=hidden_states,
                 residual=residual,
                 input_data_scatter_mode=self.layers[
-                    normal_num_layers - 1
+                    normal_end_layer - 1
                 ].layer_scatter_modes.layer_output_mode,
                 zero_allocator=zero_allocator,
            )
 
-        if not forward_batch.forward_mode.is_idle():
-            if residual is None:
-                hidden_states = self.norm(hidden_states)
-            else:
-                hidden_states, _ = self.norm(hidden_states, residual)
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            if not forward_batch.forward_mode.is_idle():
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
 
@@ -2148,6 +2120,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             "kv_a_proj_with_mqa",
         ]
 
+        self.pp_group = get_pp_group()
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
@@ -2217,13 +2190,27 @@ class DeepseekV2ForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+        hidden_states = self.model(
+            input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors
         )
 
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
+
     def post_load_weights(self, is_nextn=False, weight_names=None):
 
         # Perform post-processing after loading weights
@@ -2231,7 +2218,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             layer_ids = [self.config.num_hidden_layers]
         else:
             if weight_names is None:
-                layer_ids = range(self.config.num_hidden_layers)
+                layer_ids = range(self.model.start_layer, self.model.end_layer)
             else:
                 layer_ids = set()
                 for name in weight_names:
@@ -2478,17 +2465,15 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
        )
        if self.quant_config and self.quant_config.get_name() == "w4afp8":
-            expert_params_mapping += (
-                get_moe_impl_class().make_expert_input_scale_params_mapping(
-                    num_experts=self.config.n_routed_experts
-                )
+            expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
+                num_experts=self.config.n_routed_experts
             )
 
         # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
@@ -2515,6 +2500,16 @@ class DeepseekV2ForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
         weight_names = []
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
             if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
                 name = name.replace(
                     "mlp.shared_experts",
@@ -2599,6 +2594,12 @@ class DeepseekV2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip loading embed_tokens if not first rank in pipeline parallelism
+                if ".embed_tokens." in name and not self.pp_group.is_first_rank:
+                    continue
+                # Skip loading norm if not last rank in pipeline parallelism
+                if ".norm." in name and not self.pp_group.is_last_rank:
+                    continue
                 if fuse_qkv_a_proj and (
                     "q_a_proj" in name or "kv_a_proj_with_mqa" in name
                 ):