sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff shows the changes between the two publicly released package versions as they appear in their public registries and is provided for informational purposes only.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/models/bailing_moe.py
CHANGED
@@ -1,19 +1,51 @@
The old header comments are replaced with an Apache-2.0 license header (Copyright 2023 Antgroup and The HuggingFace Inc. team) and the module docstring "SGLang BailingMoE model". New imports: logging and typing helpers, transformers.PretrainedConfig, get_pp_group and parallel_state from sglang.srt.distributed, the EPLB helpers (get_global_expert_distribution_recorder, ModelConfigForExpertLocation, ExpertLocationDispatchInfo), LayerCommunicator/LayerScatterModes/enable_moe_dense_fully_dp from sglang.srt.layers.communicator, and the dp_attention helpers (get_attention_dp_size, get_attention_tp_rank, get_attention_tp_size).

@@ -22,356 +54,831 @@ from sglang.srt.layers.linear import (
Further imports are added (get_moe_a2a_backend, get_moe_impl_class, FusedMoE, DeepEPDispatcher, DeepEPMode, PPMissingLayer, global_server_args_dict, get_is_capture_mode, ForwardBatch/PPProxyTensors, is_cuda, is_non_idle_and_non_empty), along with module-level LoraConfig = None, a logger, and _is_cuda = is_cuda(). The model body is rewritten:
- BailingMoEMLP replaces the old MLP: gate_up_proj/down_proj accept explicit tp_rank/tp_size, only silu activation is accepted, empty batches return early, and down_proj can skip the all-reduce when reduce-scatter is used.
- BailingMoEGate holds the router weight in a configurable router dtype with an optional expert_bias (moe_router_enable_expert_bias).
- BailingMoESparseMoeBlock replaces BailingMoE: norm_topk_prob, routed_scaling_factor, score_function/correction_bias validation, grouped top-k via n_group/topk_group, experts built with get_moe_impl_class(quant_config), shared experts sized by moe_shared_expert_intermediate_size (TP disabled for them when the DeepEP a2a backend is active), a DeepEPDispatcher, a dual-stream forward_normal path for batches up to 1024 tokens under CUDA-graph capture, a forward_deepep dispatch/combine path, and get_moe_weights().
- BailingAttention becomes BailingMoEAttention: attention-TP rank/size aware QKVParallelLinear and RowParallelLinear, optional QK RMSNorm overlapped on an alternate CUDA stream, partial_rotary_factor/rotary_dim handling, and an early return for empty batches.
- BailingMoEBlock wires LayerScatterModes and LayerCommunicator (allow_reduce_scatter=True), decides sparse vs. dense layers via first_k_dense_replace, and runs dense MLPs fully data-parallel when enable_moe_dense_fully_dp() is set.
- BailingMoEModel adds pipeline parallelism: word_embeddings and the final norm exist only on the first/last PP rank (PPMissingLayer otherwise), make_layers returns start_layer/end_layer, and the forward loop records expert distribution per layer and returns PPProxyTensors on non-last ranks.
- BailingMoeForCausalLM is renamed BailingMoEForCausalLM: an optional alternate CUDA stream, lm_head either tied to word_embeddings or a ParallelLMHead (honoring enable_dp_lm_head), start_layer/end_layer properties, get_embed_and_head/set_embed_and_head for the eagle_worker, and a forward that accepts pp_proxy_tensors and runs the logits processor only on the last PP rank.

@@ -381,39 +888,87 @@ class BailingMoeForCausalLM(nn.Module):
load_weights gains an is_nextn mode (remapping the nextn layer prefix and skipping shared head/embedding weights), skips v_head, inv_freq, and tied lm_head weights, keeps the norm_head normalization, and fills in the stacked-params and expert-params mapping branches, including the GPTQ extra-bias skip.

@@ -421,5 +976,30 @@ class BailingMoeForCausalLM(nn.Module):
After loading, routed_experts_weights_of_layer is collected for sparse layers, a get_model_config_for_expert_location classmethod is added, and BailingMoeForCausalLM / BailingMoeV2ForCausalLM are registered as aliases of the renamed class; EntryClass = [BailingMoEForCausalLM, BailingMoeForCausalLM, BailingMoeV2ForCausalLM].