sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
@@ -1,19 +1,51 @@
- # Copyright 2023-2024 SGLang Team
- # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/bailing_moe.py
-
- from collections.abc import Iterable
- from typing import Optional, Tuple
+ # coding=utf-8
+ # Copyright 2023 Antgroup and The HuggingFace Inc. team. All rights reserved.
+ #
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+ # and OPT implementations in this library. It has been modified from its
+ # original forms to accommodate minor architectural differences compared
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ SGLang BailingMoE model."""
+ import logging
+ from typing import Any, Dict, Iterable, Optional, Tuple, Union

  import torch
  import torch.nn.functional as F
  from torch import nn
- from transformers.configuration_utils import PretrainedConfig
+ from transformers import PretrainedConfig

  from sglang.srt.distributed import (
+ get_pp_group,
  get_tensor_model_parallel_world_size,
+ parallel_state,
  tensor_model_parallel_all_reduce,
  )
+ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+ from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
  from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.communicator import (
+ LayerCommunicator,
+ LayerScatterModes,
+ enable_moe_dense_fully_dp,
+ )
+ from sglang.srt.layers.dp_attention import (
+ get_attention_dp_size,
+ get_attention_tp_rank,
+ get_attention_tp_size,
+ )
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
  MergedColumnParallelLinear,
@@ -22,356 +54,831 @@ from sglang.srt.layers.linear import (
  RowParallelLinear,
  )
  from sglang.srt.layers.logits_processor import LogitsProcessor
- from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+ from sglang.srt.layers.moe import get_moe_a2a_backend
+ from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+ from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+ from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
  from sglang.srt.layers.moe.topk import TopK
+ from sglang.srt.layers.moe.utils import DeepEPMode
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.rotary_embedding import get_rope
+ from sglang.srt.layers.utils import PPMissingLayer
  from sglang.srt.layers.vocab_parallel_embedding import (
  ParallelLMHead,
  VocabParallelEmbedding,
  )
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
+ from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
  from sglang.srt.model_loader.weight_utils import default_weight_loader
- from sglang.srt.utils import add_prefix, make_layers
+ from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers

+ LoraConfig = None
+ logger = logging.getLogger(__name__)
+ _is_cuda = is_cuda()

- class BailingAttention(nn.Module):

+ class BailingMoEMLP(nn.Module):
  def __init__(
  self,
+ intermediate_size: int,
  config: PretrainedConfig,
- layer_id: int = 0,
  quant_config: Optional[QuantizationConfig] = None,
+ reduce_results: Optional[bool] = True,
  prefix: str = "",
- ):
+ tp_rank: Optional[int] = None,
+ tp_size: Optional[int] = None,
+ ) -> None:
  super().__init__()
- self.hidden_size = config.hidden_size
- tp_size = get_tensor_model_parallel_world_size()
-
- self.total_num_heads = config.num_attention_heads
- self.total_num_kv_heads = config.num_key_value_heads
-
- assert self.total_num_heads % tp_size == 0
- assert self.total_num_kv_heads % tp_size == 0
-
- self.num_heads = self.total_num_heads // tp_size
- self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads)
- self.q_size = self.num_heads * self.head_dim
-
- self.num_kv_heads = self.total_num_kv_heads // tp_size
- self.kv_size = self.num_kv_heads * self.head_dim
- self.scale = self.head_dim**-0.5
+ self.tp_size = tp_size

- self.query_key_value = QKVParallelLinear(
- self.hidden_size,
- self.head_dim,
- self.total_num_heads,
- self.total_num_kv_heads,
- bias=(config.use_bias or config.use_qkv_bias),
- quant_config=quant_config,
- prefix=add_prefix("query_key_value", prefix),
- )
-
- self.dense = RowParallelLinear(
- self.total_num_heads * self.head_dim,
- self.hidden_size,
+ self.gate_up_proj = MergedColumnParallelLinear(
+ config.hidden_size,
+ [intermediate_size] * 2,
  bias=config.use_bias,
  quant_config=quant_config,
- prefix=add_prefix("dense", prefix),
+ prefix=add_prefix("gate_up_proj", prefix),
+ tp_rank=tp_rank,
+ tp_size=tp_size,
  )
-
- self.attn = RadixAttention(
- self.num_heads,
- self.head_dim,
- self.scale,
- num_kv_heads=self.num_kv_heads,
- layer_id=layer_id,
+ self.down_proj = RowParallelLinear(
+ intermediate_size,
+ config.hidden_size,
+ bias=config.use_bias,
+ reduce_results=reduce_results,
  quant_config=quant_config,
- prefix=add_prefix("attn", prefix),
+ prefix=add_prefix("down_proj", prefix),
+ tp_rank=tp_rank,
+ tp_size=tp_size,
  )

- self.rotary_emb = get_rope(
- self.head_dim,
- rotary_dim=self.head_dim,
- max_position=config.max_position_embeddings,
- base=config.rope_theta,
- is_neox_style=True,
- rope_scaling=config.rope_scaling,
- )
+ if config.hidden_act != "silu":
+ raise ValueError("Unsupported activation. Only silu is supported for now.")
+ self.act_fn = SiluAndMul()

  def forward(
  self,
  hidden_states: torch.Tensor,
- position_ids: torch.Tensor,
- forward_batch: ForwardBatch,
+ forward_batch: Optional[ForwardBatch] = None,
+ use_reduce_scatter: bool = False,
  ) -> torch.Tensor:
- qkv, _ = self.query_key_value(hidden_states)
- q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+ if (self.tp_size == 1) and hidden_states.shape[0] == 0:
+ return hidden_states

- q, k = self.rotary_emb(position_ids, q, k)
- context_layer = self.attn(q, k, v, forward_batch)
- attn_output, _ = self.dense(context_layer)
- return attn_output
+ gate_up, _ = self.gate_up_proj(hidden_states)
+ hidden_states = self.act_fn(gate_up)
+ hidden_states, _ = self.down_proj(
+ hidden_states, skip_all_reduce=use_reduce_scatter
+ )
+ return hidden_states


- class BailingMLP(nn.Module):
+ class BailingMoEGate(nn.Module):
  def __init__(
  self,
- intermediate_size: int,
- config: PretrainedConfig,
- quant_config: Optional[QuantizationConfig] = None,
- reduce_results: Optional[bool] = True,
+ config,
+ params_dtype: Optional[torch.dtype] = None,
  prefix: str = "",
- ) -> None:
+ ):
  super().__init__()
- self.gate_up_proj = MergedColumnParallelLinear(
- config.hidden_size,
- [intermediate_size] * 2,
- bias=config.use_bias,
- quant_config=quant_config,
- prefix=add_prefix("gate_up_proj", prefix),
- )
- self.down_proj = RowParallelLinear(
- intermediate_size,
- config.hidden_size,
- bias=config.use_bias,
- quant_config=quant_config,
- reduce_results=reduce_results,
- prefix=add_prefix("down_proj", prefix),
+ if params_dtype is None:
+ params_dtype = torch.get_default_dtype()
+ self.params_dtype = params_dtype
+ self.weight = nn.Parameter(
+ torch.empty(
+ (config.num_experts, config.hidden_size),
+ dtype=self.params_dtype,
+ ),
  )
- self.act_fn = SiluAndMul()
-
- def forward(self, x):
- x, _ = self.gate_up_proj(x)
- x = self.act_fn(x)
- x, _ = self.down_proj(x)
- return x
+ if getattr(config, "moe_router_enable_expert_bias", False):
+ self.expert_bias = nn.Parameter(
+ torch.empty((config.num_experts,), dtype=torch.float32),
+ )
+ else:
+ self.expert_bias = None

+ def forward(self, hidden_states):
+ logits = F.linear(hidden_states.to(self.weight.dtype), self.weight, None).to(
+ hidden_states.dtype
+ )
+ return logits

- class BailingMoE(nn.Module):

+ class BailingMoESparseMoeBlock(nn.Module):
  def __init__(
  self,
- config: PretrainedConfig,
  layer_id: int,
+ config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ alt_stream: Optional[torch.cuda.Stream] = None,
  prefix: str = "",
  ):
  super().__init__()
+ self.layer_id = layer_id
+ self.alt_stream = alt_stream
  self.tp_size = get_tensor_model_parallel_world_size()
- self.num_experts = config.num_experts
  self.top_k = config.num_experts_per_tok
+ self.norm_topk_prob = config.norm_topk_prob
  self.hidden_size = config.hidden_size
  self.num_shared_experts = config.num_shared_experts
- self.norm_expert_prob = config.norm_topk_prob
- self.moe_intermediate_size = config.moe_intermediate_size
+ self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
+ self.score_function = getattr(config, "score_function", None)
+
+ if config.hidden_act != "silu":
+ raise ValueError(
+ f"Unsupported activation: {config.hidden_act}. "
+ "Only silu is supported for now."
+ )
+
+ # Gate always runs at half / full precision for now.
+ router_dtype = getattr(config, "router_dtype", None)
+ if router_dtype is None:
+ self.router_dtype = None
+ elif router_dtype == "fp32":
+ self.router_dtype = torch.float32
+ else:
+ self.router_dtype = torch.bfloat16
+
+ # TODO global_server_args_dict["ep_num_redundant_experts"] is used for eplb, not supported now
+ assert global_server_args_dict["ep_num_redundant_experts"] == 0
+ # check group topk
+ self.num_expert_group = getattr(config, "n_group", 0)
+ self.topk_group = getattr(config, "topk_group", 0)
+ if self.num_expert_group > 0 or self.topk_group > 0:
+ assert (
+ self.num_expert_group > 0
+ and 0 < self.topk_group <= self.num_expert_group
+ )
+ self.use_grouped_topk = True
+ else:
+ self.num_expert_group = self.topk_group = None
+ self.use_grouped_topk = False

- self.gate = ReplicatedLinear(
- self.hidden_size, self.num_experts, bias=False, quant_config=None
+ self.num_experts = (
+ config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
  )

- self.topk = TopK(top_k=self.top_k, renormalize=self.norm_expert_prob)
+ self.gate = BailingMoEGate(
+ config=config,
+ params_dtype=self.router_dtype,
+ prefix=add_prefix("gate", prefix),
+ )
+ self.correction_bias = (
+ self.gate.expert_bias.data if self.gate.expert_bias is not None else None
+ )
+
+ if self.score_function is not None:
+ assert (
+ self.score_function == "softmax" and self.correction_bias is None
+ ) or (
+ self.score_function == "sigmoid" and self.correction_bias is not None
+ ), "score_function and correction_bias should be in 2 combination (softmax, None) or (sigmoid, not None)"
+
+ self.topk = TopK(
+ top_k=self.top_k,
+ renormalize=self.norm_topk_prob,
+ use_grouped_topk=self.use_grouped_topk,
+ num_expert_group=self.num_expert_group,
+ # num_fused_shared_experts=self.num_fused_shared_experts,
+ topk_group=self.topk_group,
+ correction_bias=self.correction_bias,
+ routed_scaling_factor=self.routed_scaling_factor,
+ )

- self.experts = FusedMoE(
+ self.experts = get_moe_impl_class(quant_config)(
  num_experts=self.num_experts,
  top_k=self.top_k,
- layer_id=layer_id,
- hidden_size=self.hidden_size,
- intermediate_size=self.moe_intermediate_size,
- reduce_results=False,
+ layer_id=self.layer_id,
+ hidden_size=config.hidden_size,
+ intermediate_size=config.moe_intermediate_size,
  quant_config=quant_config,
+ routed_scaling_factor=self.routed_scaling_factor,
  prefix=add_prefix("experts", prefix),
  )
-
- if self.num_shared_experts > 0:
- shared_intermediate_size = (
- self.moe_intermediate_size * self.num_shared_experts
- )
- self.shared_experts = BailingMLP(
- intermediate_size=shared_intermediate_size,
+ # shared expert
+ if config.num_shared_experts is not None:
+ if hasattr(config, "moe_shared_expert_intermediate_size"):
+ intermediate_size = config.moe_shared_expert_intermediate_size
+ else:
+ intermediate_size = config.moe_intermediate_size
+ intermediate_size *= config.num_shared_experts
+ # disable tp for shared experts when enable deepep moe
+ self.shared_experts = BailingMoEMLP(
+ intermediate_size=intermediate_size,
  config=config,
  quant_config=quant_config,
  reduce_results=False,
  prefix=add_prefix("shared_experts", prefix),
+ **(
+ dict(tp_rank=0, tp_size=1)
+ if get_moe_a2a_backend().is_deepep()
+ else {}
+ ),
+ )
+ # dispatcher
+ if get_moe_a2a_backend().is_deepep():
+ # TODO: we will support tp < ep in the future
+ self.ep_size = get_tensor_model_parallel_world_size()
+
+ self.deepep_dispatcher = DeepEPDispatcher(
+ group=parallel_state.get_tp_group().device_group,
+ router_topk=self.top_k,
+ permute_fusion=True,
+ num_experts=self.num_experts,
+ num_local_experts=config.num_experts // self.tp_size,
+ hidden_size=config.hidden_size,
+ params_dtype=config.torch_dtype,
+ deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
+ async_finish=True, # TODO
+ return_recv_hook=True,
  )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ forward_batch: Optional[ForwardBatch] = None,
+ use_reduce_scatter: bool = False,
+ ) -> torch.Tensor:
+ if not get_moe_a2a_backend().is_deepep():
+ return self.forward_normal(hidden_states, use_reduce_scatter)
  else:
- self.shared_experts = None
+ return self.forward_deepep(hidden_states, forward_batch)

- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- orig_shape = hidden_states.shape
- hidden_states_flat = hidden_states.view(-1, self.hidden_size)
+ def get_moe_weights(self):
+ return [
+ x.data
+ for name, x in self.experts.named_parameters()
+ if name not in ["correction_bias"]
+ ]

+ def _forward_shared_experts(self, hidden_states: torch.Tensor):
  shared_output = None
- if self.shared_experts is not None:
- shared_output = self.shared_experts(hidden_states_flat)
+ if self.num_shared_experts > 0:
+ shared_output = self.shared_experts(hidden_states)
+ return shared_output

- router_logits, _ = self.gate(hidden_states_flat)
- topk_output = self.topk(hidden_states_flat, router_logits)
- final_hidden_states = self.experts(hidden_states_flat, topk_output)
+ def _forward_router_experts(self, hidden_states: torch.Tensor):
+ # router_logits: (num_tokens, n_experts)
+ router_logits = self.gate(hidden_states)
+ topk_output = self.topk(hidden_states, router_logits)
+ return self.experts(hidden_states, topk_output)

- if shared_output is not None:
+ def forward_normal_dual_stream(
+ self,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor:
+ current_stream = torch.cuda.current_stream()
+ self.alt_stream.wait_stream(current_stream)
+ shared_output = self._forward_shared_experts(hidden_states.clone())
+
+ with torch.cuda.stream(self.alt_stream):
+ router_output = self._forward_router_experts(hidden_states)
+ current_stream.wait_stream(self.alt_stream)
+
+ return router_output, shared_output
+
+ def forward_normal(
+ self,
+ hidden_states: torch.Tensor,
+ use_reduce_scatter: bool = False,
+ ) -> torch.Tensor:
+ num_tokens, hidden_size = hidden_states.shape
+ hidden_states = hidden_states.view(-1, hidden_size)
+
+ DUAL_STREAM_TOKEN_THRESHOLD = 1024
+ if (
+ self.alt_stream is not None
+ and hidden_states.shape[0] > 0
+ and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+ and get_is_capture_mode()
+ ):
+ final_hidden_states, shared_output = self.forward_normal_dual_stream(
+ hidden_states
+ )
+ else:
+ shared_output = self._forward_shared_experts(hidden_states)
+ final_hidden_states = self._forward_router_experts(hidden_states)
+
+ if self.num_shared_experts > 0:
  final_hidden_states = final_hidden_states + shared_output

- if self.tp_size > 1:
+ if self.tp_size > 1 and not use_reduce_scatter:
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+ return final_hidden_states.view(num_tokens, hidden_size)

- return final_hidden_states.view(orig_shape)
+ def forward_deepep(
+ self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+ ) -> torch.Tensor:
+ shared_output = None
+ forward_mode = forward_batch.forward_mode
+ if is_non_idle_and_non_empty(forward_mode, hidden_states):
+ router_logits = self.gate(hidden_states)
+ if self.num_shared_experts > 0:
+ shared_output = self.shared_experts(hidden_states)

+ topk_weights, topk_idx, _ = self.topk(
+ hidden_states,
+ router_logits,
+ num_token_non_padded=forward_batch.num_token_non_padded,
+ expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+ layer_id=self.layer_id,
+ ),
+ )
+ else:
+ topk_idx = torch.full(
+ (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+ )
+ topk_weights = torch.empty(
+ (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+ )
+
+ if self.ep_size > 1:
+ (
+ hidden_states,
+ topk_idx,
+ topk_weights,
+ reorder_topk_ids,
+ num_recv_tokens_per_expert,
+ seg_indptr,
+ masked_m,
+ expected_m,
+ ) = self.deepep_dispatcher.dispatch(
+ hidden_states,
+ topk_idx,
+ topk_weights,
+ forward_batch=forward_batch,
+ )

- class BailingMoeBlock(nn.Module):
+ final_hidden_states = self.experts(
+ hidden_states=hidden_states,
+ topk_idx=topk_idx,
+ topk_weights=topk_weights,
+ reorder_topk_ids=reorder_topk_ids,
+ seg_indptr=seg_indptr,
+ masked_m=masked_m,
+ expected_m=expected_m,
+ num_recv_tokens_per_expert=num_recv_tokens_per_expert,
+ forward_batch=forward_batch,
+ )
+ if self.ep_size > 1:
+ final_hidden_states = self.deepep_dispatcher.combine(
+ final_hidden_states,
+ topk_idx,
+ topk_weights,
+ forward_batch=forward_batch,
+ )

+ final_hidden_states *= self.routed_scaling_factor
+
+ if shared_output is not None:
+ final_hidden_states = final_hidden_states + shared_output
+ return final_hidden_states
+
+
+ class BailingMoEAttention(nn.Module):
  def __init__(
  self,
  config: PretrainedConfig,
- layer_id: int,
+ layer_id: int = 0,
  quant_config: Optional[QuantizationConfig] = None,
+ reduce_results: bool = True,
  prefix: str = "",
+ alt_stream: Optional[torch.cuda.Stream] = None,
  ):
  super().__init__()
- self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.attention = BailingAttention(
- config, layer_id, quant_config, prefix=add_prefix("attention", prefix)
+ self.hidden_size = config.hidden_size
+ self.total_num_heads = config.num_attention_heads
+ self.total_kv_heads = config.num_key_value_heads
+ self.dp_size = get_attention_dp_size()
+ attn_tp_rank = get_attention_tp_rank()
+ attn_tp_size = get_attention_tp_size()
+
+ assert self.total_num_heads % attn_tp_size == 0
+ assert self.total_kv_heads % attn_tp_size == 0
+ assert self.total_num_heads >= self.total_kv_heads
+
+ self.num_heads = self.total_num_heads // attn_tp_size
+ self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads)
+ self.q_size = self.head_dim * self.num_heads
+
+ self.num_kv_heads = self.total_kv_heads // attn_tp_size
+ self.kv_size = max(1, self.num_kv_heads * self.head_dim)
+
+ self.scale = self.head_dim**-0.5
+
+ self.use_qk_norm = getattr(config, "use_qk_norm", False)
+
+ self.query_key_value = QKVParallelLinear(
+ self.hidden_size,
+ self.head_dim,
+ self.total_num_heads,
+ self.total_kv_heads,
+ bias=(config.use_bias or config.use_qkv_bias),
+ quant_config=quant_config,
+ prefix=add_prefix("query_key_value", prefix),
+ tp_rank=attn_tp_rank,
+ tp_size=attn_tp_size,
  )
- self.post_attention_layernorm = RMSNorm(
- config.hidden_size, eps=config.rms_norm_eps
+
+ if self.use_qk_norm:
+ self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+ self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+ self.dense = RowParallelLinear(
+ self.total_num_heads * self.head_dim,
+ self.hidden_size,
+ bias=config.use_bias,
+ quant_config=quant_config,
+ reduce_results=reduce_results,
+ prefix=add_prefix("dense", prefix),
+ tp_rank=attn_tp_rank,
+ tp_size=attn_tp_size,
  )
- self.mlp = BailingMoE(
- config=config,
+
+ if hasattr(config, "partial_rotary_factor"):
+ self.rotary_dim = int(self.head_dim * config.partial_rotary_factor)
+ elif hasattr(config, "rotary_dim"):
+ self.rotary_dim = config.rotary_dim
+ else:
+ self.rotary_dim = self.head_dim
+ self.rotary_emb = get_rope(
+ self.head_dim,
+ rotary_dim=self.rotary_dim,
+ max_position=config.max_position_embeddings,
+ base=config.rope_theta,
+ rope_scaling=config.rope_scaling,
+ )
+
+ self.attn = RadixAttention(
+ self.num_heads,
+ self.head_dim,
+ self.scale,
+ num_kv_heads=self.num_kv_heads,
  layer_id=layer_id,
- quant_config=quant_config,
- prefix=add_prefix("mlp", prefix),
+ prefix=add_prefix("attn", prefix),
  )

+ self.alt_stream = alt_stream
+
+ def _apply_qk_norm(
+ self, q: torch.Tensor, k: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # overlap qk norm
+ if self.alt_stream is not None and get_is_capture_mode():
+ current_stream = torch.cuda.current_stream()
+ self.alt_stream.wait_stream(current_stream)
+ q_by_head = q.reshape(-1, self.head_dim)
+ q_by_head = self.query_layernorm(q_by_head)
+ with torch.cuda.stream(self.alt_stream):
+ k_by_head = k.reshape(-1, self.head_dim)
+ k_by_head = self.key_layernorm(k_by_head)
+ current_stream.wait_stream(self.alt_stream)
+ else:
+ q_by_head = q.reshape(-1, self.head_dim)
+ q_by_head = self.query_layernorm(q_by_head)
+ k_by_head = k.reshape(-1, self.head_dim)
+ k_by_head = self.key_layernorm(k_by_head)
+ q = q_by_head.view(q.shape)
+ k = k_by_head.view(k.shape)
+ return q, k
+
  def forward(
  self,
+ positions: torch.Tensor,
  hidden_states: torch.Tensor,
- position_ids: torch.Tensor,
- residual: Optional[torch.Tensor],
  forward_batch: ForwardBatch,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- # Pre-normalization and residual connection for the attention block
- if residual is None:
- residual = hidden_states
- normed_hidden_states = self.input_layernorm(hidden_states)
+ ) -> torch.Tensor:
+ if hidden_states.shape[0] == 0:
+ return hidden_states
+ qkv, _ = self.query_key_value(hidden_states)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+ if self.use_qk_norm:
+ q, k = self._apply_qk_norm(q, k)
+ q, k = self.rotary_emb(positions, q, k)
+ context_layer = self.attn(q, k, v, forward_batch)
+ attn_output, _ = self.dense(context_layer)
+ return attn_output
+
+
+ class BailingMoEBlock(nn.Module):
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ layer_id: int = 0,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ alt_stream: Optional[torch.cuda.Stream] = None,
+ ):
+ super().__init__()
+ hidden_size = config.hidden_size
+
+ self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
+ self.dp_size = get_attention_dp_size()
+ self.attention = BailingMoEAttention(
+ config,
+ layer_id,
+ quant_config,
+ reduce_results=False,
+ prefix=add_prefix("attention", prefix),
+ alt_stream=alt_stream,
+ )
+ self.layer_id = layer_id
+ self.attn_tp_size = get_attention_tp_size()
+ self.attn_tp_rank = get_attention_tp_rank()
+
+ self.is_layer_sparse = self._is_layer_sparse(
+ config, layer_id=layer_id, is_nextn=False
+ )
+ is_previous_layer_sparse = self._is_layer_sparse(
+ config, layer_id=layer_id - 1, is_nextn=False
+ )
+
+ self.layer_scatter_modes = LayerScatterModes.init_new(
+ layer_id=layer_id,
+ num_layers=config.num_hidden_layers,
+ is_layer_sparse=self.is_layer_sparse,
+ is_previous_layer_sparse=is_previous_layer_sparse,
+ )
+
+ self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
+
+ if self.is_layer_sparse:
+ self.mlp = BailingMoESparseMoeBlock(
+ layer_id=layer_id,
+ config=config,
+ quant_config=quant_config,
+ alt_stream=alt_stream,
+ prefix=add_prefix("mlp", prefix),
+ )
  else:
- normed_hidden_states, residual = self.input_layernorm(
- hidden_states, residual
+ if enable_moe_dense_fully_dp():
+ mlp_tp_rank, mlp_tp_size = 0, 1
+ else:
+ mlp_tp_rank, mlp_tp_size = None, None
+ self.mlp = BailingMoEMLP(
+ intermediate_size=config.intermediate_size,
+ config=config,
+ quant_config=quant_config,
+ prefix=add_prefix("mlp", prefix),
+ tp_rank=mlp_tp_rank,
+ tp_size=mlp_tp_size,
  )

- attn_output = self.attention(
- hidden_states=normed_hidden_states,
- position_ids=position_ids,
+ self.post_attention_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
+
+ self.layer_communicator = LayerCommunicator(
+ layer_scatter_modes=self.layer_scatter_modes,
+ input_layernorm=self.input_layernorm,
+ post_attention_layernorm=self.post_attention_layernorm,
+ allow_reduce_scatter=True,
+ )
+
+ def _is_layer_sparse(
+ self, config: PretrainedConfig, layer_id: int, is_nextn: bool
+ ) -> bool:
+ return is_nextn or (
+ config.num_experts is not None and layer_id >= config.first_k_dense_replace
+ )
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ forward_batch: ForwardBatch,
+ residual: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+ hidden_states, residual = self.layer_communicator.prepare_attn(
+ hidden_states=hidden_states,
+ residual=residual,
+ forward_batch=forward_batch,
+ )
+
+ hidden_states = self.attention(
+ positions=positions,
+ hidden_states=hidden_states,
  forward_batch=forward_batch,
  )

- # Pre-normalization and residual connection for the MLP block
- normed_hidden_states, residual = self.post_attention_layernorm(
- attn_output, residual
+ hidden_states, residual = self.layer_communicator.prepare_mlp(
+ hidden_states=hidden_states,
+ residual=residual,
+ forward_batch=forward_batch,
+ )
+
+ # For DP with padding, reduce scatter can be used instead of all-reduce.
+ use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+ forward_batch
  )
- mlp_output = self.mlp(normed_hidden_states)

- return mlp_output, residual
+ hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)

+ hidden_states, residual = self.layer_communicator.postprocess_layer(
+ hidden_states=hidden_states,
+ residual=residual,
+ forward_batch=forward_batch,
+ )

- class BailingMoeModel(nn.Module):
+ return hidden_states, residual
+
+
+ class BailingMoEModel(nn.Module):

  def __init__(
  self,
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ alt_stream: Optional[torch.cuda.Stream] = None,
  prefix: str = "",
  ):
  super().__init__()
+ self.pp_group = get_pp_group()
  self.config = config
- self.padding_idx = config.pad_token_id
  self.vocab_size = config.vocab_size
  self.embed_dim = config.hidden_size
+ if self.pp_group.is_first_rank:
+ self.word_embeddings = VocabParallelEmbedding(
+ self.vocab_size,
+ self.embed_dim,
+ quant_config=quant_config,
+ prefix=add_prefix("word_embeddings", prefix),
+ use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+ )
+ else:
+ self.word_embeddings = PPMissingLayer()

- self.embed_tokens = VocabParallelEmbedding(
- config.vocab_size,
- config.hidden_size,
- prefix=add_prefix("embed_tokens", prefix),
- )
  self.embedding_dropout = torch.nn.Dropout(config.embedding_dropout)

- self.layers = make_layers(
+ self.layers, self.start_layer, self.end_layer = make_layers(
  config.num_hidden_layers,
- lambda idx, prefix: BailingMoeBlock(
- config=config,
+ lambda idx, prefix: BailingMoEBlock(
  layer_id=idx,
+ config=config,
  quant_config=quant_config,
  prefix=prefix,
+ alt_stream=alt_stream,
  ),
+ pp_rank=self.pp_group.rank_in_group,
+ pp_size=self.pp_group.world_size,
  prefix=add_prefix("layers", prefix),
  )
-
- self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
+ if self.pp_group.is_last_rank:
+ self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
+ else:
+ self.norm = PPMissingLayer(return_tuple=True)

  def forward(
  self,
  input_ids: torch.Tensor,
- position_ids: torch.Tensor,
+ positions: torch.Tensor,
  forward_batch: ForwardBatch,
- input_embeds: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- if input_embeds is None:
- hidden_states = self.embed_tokens(input_ids)
+ input_embeds: torch.Tensor = None,
+ pp_proxy_tensors: Optional[PPProxyTensors] = None,
+ ) -> Union[torch.Tensor, PPProxyTensors]:
+ if self.pp_group.is_first_rank:
+ if input_embeds is None:
+ hidden_states = self.word_embeddings(input_ids)
+ else:
+ hidden_states = input_embeds
+ residual = None
  else:
- hidden_states = input_embeds
-
- residual = None
- for layer in self.layers:
- hidden_states, residual = layer(
- hidden_states,
- position_ids,
- residual,
- forward_batch,
+ assert pp_proxy_tensors is not None
+ hidden_states = pp_proxy_tensors["hidden_states"]
+ residual = pp_proxy_tensors["residual"]
+
+ for i in range(self.start_layer, self.end_layer):
+ with get_global_expert_distribution_recorder().with_current_layer(i):
+ layer = self.layers[i]
+ hidden_states, residual = layer(
+ positions,
+ hidden_states,
+ forward_batch,
+ residual,
+ )
+ if not self.pp_group.is_last_rank:
+ return PPProxyTensors(
+ {
+ "hidden_states": hidden_states,
+ "residual": residual,
+ }
  )
+ else:
+ if not forward_batch.forward_mode.is_idle():
+ if residual is None:
+ hidden_states = self.norm(hidden_states)
+ else:
+ hidden_states, _ = self.norm(hidden_states, residual)
+ return hidden_states

- hidden_states, _ = self.norm(hidden_states, residual)
- return hidden_states
-
-
- class BailingMoeForCausalLM(nn.Module):

+ class BailingMoEForCausalLM(nn.Module):
  def __init__(
  self,
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
- ) -> None:
+ prefix: str = "",
+ ):
  super().__init__()
+ self.pp_group = get_pp_group()
  self.config = config
- self.model = BailingMoeModel(config=config, quant_config=quant_config)
- self.lm_head = ParallelLMHead(
- num_embeddings=config.vocab_size,
- embedding_dim=config.hidden_size,
- quant_config=quant_config,
+ self.quant_config = quant_config
+ alt_stream = torch.cuda.Stream() if _is_cuda else None
+
+ self.model = BailingMoEModel(
+ config,
+ quant_config,
+ alt_stream=alt_stream,
+ prefix=add_prefix("model", ""),
  )
- if config.tie_word_embeddings:
- self.lm_head.weight = self.model.embed_tokens.weight

+ # tie_word_embeddings为true,复用tie_word_embeddings,反之是独立的
+ if config.tie_word_embeddings:
+ self.lm_head = self.model.word_embeddings
+ else:
+ # TODO something wrong with ParallelLMHead with DP attention enabled
+ self.lm_head = ParallelLMHead(
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=add_prefix("lm_head", prefix),
+ use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+ )
  self.logits_processor = LogitsProcessor(config)

+ @property
+ def start_layer(self):
+ return self.model.start_layer
+
+ @property
+ def end_layer(self):
+ return self.model.end_layer
+
+ def get_embed_and_head(self):
+ """Used by the eagle_worker."""
+ return self.model.word_embeddings.weight, self.lm_head.weight
+
+ def set_embed_and_head(self, embed, head):
+ """Used by the eagle_worker."""
+ del self.model.word_embeddings.weight
+ del self.lm_head.weight
+ self.model.word_embeddings.weight = embed
+ self.lm_head.weight = head
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+
+ @torch.no_grad()
  def forward(
  self,
  input_ids: torch.Tensor,
  positions: torch.Tensor,
  forward_batch: ForwardBatch,
- inputs_embeds: Optional[torch.Tensor] = None,
+ input_embeds: torch.Tensor = None,
+ pp_proxy_tensors: Optional[PPProxyTensors] = None,
  ) -> torch.Tensor:
- hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds)
- return self.logits_processor(
- input_ids, hidden_states, self.lm_head, forward_batch
+ hidden_states = self.model(
+ input_ids,
+ positions,
+ forward_batch,
+ input_embeds,
+ pp_proxy_tensors=pp_proxy_tensors,
  )
-
- def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ if self.pp_group.is_last_rank:
+ return self.logits_processor(
+ input_ids, hidden_states, self.lm_head, forward_batch
+ )
+ else:
+ return hidden_states
+
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
+ if is_nextn:
+ if hasattr(self.config, "num_nextn_predict_layers"):
+ num_nextn_layers = self.config.num_nextn_predict_layers
+ assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
+ # compatible with old design
+ nextn_layer_id = (
+ 0
+ if self.config.num_hidden_layers == 1
+ else self.config.num_hidden_layers
+ )
+ else:
+ raise ValueError("num_nextn_predict_layers is not in the config")

  stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
  ("gate_up_proj", "gate_proj", 0),
  ("gate_up_proj", "up_proj", 1),
  ]

+ if is_nextn:
+ nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
+ nextn_spec_weight_names = [
+ "final_layernorm",
+ "eh_proj",
+ "enorm",
+ "hnorm",
+ ]
+ # Params for weights, fp8 weight scales, fp8 activation scales
+ # (param_name, weight_name, expert_id, shard_id)
  expert_params_mapping = FusedMoE.make_expert_params_mapping(
  ckpt_gate_proj_name="gate_proj",
  ckpt_down_proj_name="down_proj",
@@ -381,39 +888,87 @@ class BailingMoeForCausalLM(nn.Module):

  params_dict = dict(self.named_parameters())
  for name, loaded_weight in weights:
+ if (
+ ("v_head" in name)
+ or ("inv_freq" in name)
+ or (self.config.tie_word_embeddings and "lm_head" in name)
+ ):
+ continue

  if (
  hasattr(self.config, "norm_head")
  and self.config.norm_head
  and "lm_head.weight" in name
  ):
+ import torch.nn.functional as F
+
  loaded_weight = F.normalize(loaded_weight, dim=0, p=2, eps=1e-7)

- if "model.word_embeddings.weight" == name:
- name = "model.embed_tokens.weight"
+ if is_nextn:
+ if not name.startswith(nextn_layer_prefix):
+ continue
+
+ # Use shared head and embed weights from target model
+ if "shared_head.head" in name or "embed_tokens" in name:
+ continue
+
+ is_decoder = True
+ # For nextn specific weights
+ for weight_name in nextn_spec_weight_names:
+ if weight_name in name:
+ name = name.replace(nextn_layer_prefix, "model")
+ is_decoder = False
+ break
+ # For decoder layer weights
+ if is_decoder:
+ name = name.replace(nextn_layer_prefix, "model.decoder")

  for param_name, weight_name, shard_id in stacked_params_mapping:
- if weight_name in name and "mlp.experts" not in name:
- full_param_name = name.replace(weight_name, param_name)
- param = params_dict[full_param_name]
- param.weight_loader(param, loaded_weight, shard_id)
- break
+ if weight_name not in name:
+ continue
+ # We have mlp.experts[0].gate_proj in the checkpoint.
+ # Since we handle the experts below in expert_params_mapping,
+ # we need to skip here BEFORE we update the name, otherwise
+ # name will be updated to mlp.experts[0].gate_up_proj, which
+ # will then be updated below in expert_params_mapping
+ # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+ if "mlp.experts" in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ if name not in params_dict:
+ continue
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
  else:
- for p_name, w_name, e_id, s_id in expert_params_mapping:
- if w_name in name and "mlp.experts" in name:
- full_param_name = name.replace(w_name, p_name)
- param = params_dict[full_param_name]
- param.weight_loader(
- param,
- loaded_weight,
- full_param_name,
- shard_id=s_id,
- expert_id=e_id,
- )
- break
+ for mapping in expert_params_mapping:
+ param_name, weight_name, expert_id, shard_id = mapping
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ if name not in params_dict:
+ continue
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(
+ param,
+ loaded_weight,
+ name,
+ shard_id=shard_id,
+ expert_id=expert_id,
+ )
+ break
  else:
+ # Skip loading extra bias for GPTQ models.
  if name.endswith(".bias") and name not in params_dict:
  continue
+ if name not in params_dict:
+ continue

  param = params_dict[name]
  weight_loader = getattr(
@@ -421,5 +976,30 @@ class BailingMoeForCausalLM(nn.Module):
  )
  weight_loader(param, loaded_weight)

+ if not is_nextn:
+ self.routed_experts_weights_of_layer = {
+ layer_id: layer.mlp.get_moe_weights()
+ for layer_id, layer in enumerate(self.model.layers)
+ if not isinstance(layer, PPMissingLayer)
+ and isinstance(layer.mlp, BailingMoESparseMoeBlock)
+ }
+
+ @classmethod
+ def get_model_config_for_expert_location(cls, config):
+ num_groups = getattr(config, "n_group", 0)
+ return ModelConfigForExpertLocation(
+ num_layers=config.num_hidden_layers,
+ num_logical_experts=config.num_experts,
+ num_groups=None if num_groups == 0 else num_groups,
+ )
+
+
+ class BailingMoeForCausalLM(BailingMoEForCausalLM):
+ pass
+
+
+ class BailingMoeV2ForCausalLM(BailingMoEForCausalLM):
+ pass
+

- EntryClass = BailingMoeForCausalLM
+ EntryClass = [BailingMoEForCausalLM, BailingMoeForCausalLM, BailingMoeV2ForCausalLM]