sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +149 -34
- sglang/bench_serving.py +73 -14
- sglang/compile_deep_gemm.py +13 -7
- sglang/launch_server.py +2 -0
- sglang/srt/batch_invariant_ops/__init__.py +2 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
- sglang/srt/checkpoint_engine/__init__.py +9 -0
- sglang/srt/checkpoint_engine/update.py +317 -0
- sglang/srt/compilation/backend.py +1 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/deepseek_ocr.py +542 -10
- sglang/srt/configs/deepseekvl2.py +95 -194
- sglang/srt/configs/kimi_linear.py +160 -0
- sglang/srt/configs/mamba_utils.py +66 -0
- sglang/srt/configs/model_config.py +30 -7
- sglang/srt/constants.py +7 -0
- sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
- sglang/srt/disaggregation/decode.py +34 -6
- sglang/srt/disaggregation/nixl/conn.py +2 -2
- sglang/srt/disaggregation/prefill.py +25 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
- sglang/srt/distributed/parallel_state.py +9 -12
- sglang/srt/entrypoints/engine.py +31 -20
- sglang/srt/entrypoints/grpc_server.py +0 -1
- sglang/srt/entrypoints/http_server.py +94 -94
- sglang/srt/entrypoints/openai/protocol.py +7 -1
- sglang/srt/entrypoints/openai/serving_chat.py +42 -0
- sglang/srt/entrypoints/openai/serving_completions.py +10 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/environ.py +23 -2
- sglang/srt/eplb/expert_distribution.py +64 -1
- sglang/srt/eplb/expert_location.py +106 -36
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/grpc/compile_proto.py +3 -0
- sglang/srt/layers/activation.py +6 -0
- sglang/srt/layers/attention/ascend_backend.py +233 -5
- sglang/srt/layers/attention/attention_registry.py +3 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
- sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
- sglang/srt/layers/attention/fla/kda.py +1359 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
- sglang/srt/layers/attention/flashattention_backend.py +19 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
- sglang/srt/layers/attention/mamba/mamba.py +20 -11
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
- sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
- sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
- sglang/srt/layers/attention/nsa/transform_index.py +1 -1
- sglang/srt/layers/attention/nsa_backend.py +157 -23
- sglang/srt/layers/attention/triton_backend.py +4 -1
- sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
- sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
- sglang/srt/layers/attention/utils.py +78 -0
- sglang/srt/layers/communicator.py +24 -1
- sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/layernorm.py +35 -6
- sglang/srt/layers/logits_processor.py +9 -20
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
- sglang/srt/layers/moe/ep_moe/layer.py +78 -289
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
- sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
- sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +35 -10
- sglang/srt/layers/moe/utils.py +3 -4
- sglang/srt/layers/pooler.py +21 -2
- sglang/srt/layers/quantization/__init__.py +13 -84
- sglang/srt/layers/quantization/auto_round.py +394 -0
- sglang/srt/layers/quantization/awq.py +0 -3
- sglang/srt/layers/quantization/base_config.py +7 -0
- sglang/srt/layers/quantization/fp8.py +68 -63
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/modelopt_quant.py +168 -11
- sglang/srt/layers/quantization/mxfp4.py +30 -38
- sglang/srt/layers/quantization/unquant.py +23 -45
- sglang/srt/layers/quantization/w4afp8.py +38 -2
- sglang/srt/layers/radix_attention.py +5 -2
- sglang/srt/layers/rotary_embedding.py +130 -46
- sglang/srt/layers/sampler.py +12 -1
- sglang/srt/lora/lora_registry.py +9 -0
- sglang/srt/managers/async_mm_data_processor.py +122 -0
- sglang/srt/managers/data_parallel_controller.py +30 -3
- sglang/srt/managers/detokenizer_manager.py +3 -0
- sglang/srt/managers/io_struct.py +29 -4
- sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
- sglang/srt/managers/schedule_batch.py +74 -15
- sglang/srt/managers/scheduler.py +185 -144
- sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
- sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
- sglang/srt/managers/scheduler_pp_mixin.py +7 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
- sglang/srt/managers/session_controller.py +6 -5
- sglang/srt/managers/tokenizer_manager.py +165 -78
- sglang/srt/managers/tp_worker.py +24 -1
- sglang/srt/mem_cache/base_prefix_cache.py +23 -4
- sglang/srt/mem_cache/common.py +1 -0
- sglang/srt/mem_cache/hicache_storage.py +7 -1
- sglang/srt/mem_cache/memory_pool.py +253 -57
- sglang/srt/mem_cache/memory_pool_host.py +12 -5
- sglang/srt/mem_cache/radix_cache.py +4 -0
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
- sglang/srt/metrics/collector.py +46 -3
- sglang/srt/model_executor/cuda_graph_runner.py +15 -3
- sglang/srt/model_executor/forward_batch_info.py +55 -14
- sglang/srt/model_executor/model_runner.py +77 -170
- sglang/srt/model_executor/npu_graph_runner.py +7 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/bailing_moe.py +9 -2
- sglang/srt/models/deepseek_nextn.py +11 -2
- sglang/srt/models/deepseek_v2.py +296 -78
- sglang/srt/models/glm4.py +391 -77
- sglang/srt/models/glm4_moe.py +322 -354
- sglang/srt/models/glm4_moe_nextn.py +4 -14
- sglang/srt/models/glm4v.py +196 -55
- sglang/srt/models/glm4v_moe.py +29 -197
- sglang/srt/models/gpt_oss.py +1 -10
- sglang/srt/models/kimi_linear.py +678 -0
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/llama_eagle3.py +11 -1
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/qwen2.py +23 -2
- sglang/srt/models/qwen2_moe.py +30 -15
- sglang/srt/models/qwen3.py +35 -5
- sglang/srt/models/qwen3_moe.py +18 -12
- sglang/srt/models/qwen3_next.py +7 -0
- sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
- sglang/srt/multimodal/processors/base_processor.py +1 -0
- sglang/srt/multimodal/processors/glm4v.py +1 -1
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
- sglang/srt/multiplex/multiplexing_mixin.py +209 -0
- sglang/srt/multiplex/pdmux_context.py +164 -0
- sglang/srt/parser/conversation.py +7 -1
- sglang/srt/parser/reasoning_parser.py +28 -1
- sglang/srt/sampling/custom_logit_processor.py +67 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
- sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
- sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
- sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
- sglang/srt/server_args.py +459 -199
- sglang/srt/single_batch_overlap.py +2 -4
- sglang/srt/speculative/draft_utils.py +16 -0
- sglang/srt/speculative/eagle_info.py +42 -36
- sglang/srt/speculative/eagle_info_v2.py +68 -25
- sglang/srt/speculative/eagle_utils.py +261 -16
- sglang/srt/speculative/eagle_worker.py +11 -3
- sglang/srt/speculative/eagle_worker_v2.py +15 -9
- sglang/srt/speculative/spec_info.py +305 -31
- sglang/srt/speculative/spec_utils.py +44 -8
- sglang/srt/tracing/trace.py +121 -12
- sglang/srt/utils/common.py +142 -74
- sglang/srt/utils/hf_transformers_utils.py +38 -12
- sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
- sglang/test/kits/radix_cache_server_kit.py +50 -0
- sglang/test/runners.py +31 -7
- sglang/test/simple_eval_common.py +5 -3
- sglang/test/simple_eval_humaneval.py +1 -0
- sglang/test/simple_eval_math.py +1 -0
- sglang/test/simple_eval_mmlu.py +1 -0
- sglang/test/simple_eval_mmmu_vlm.py +1 -0
- sglang/test/test_deterministic.py +235 -12
- sglang/test/test_deterministic_utils.py +2 -1
- sglang/test/test_utils.py +7 -1
- sglang/version.py +1 -1
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
- sglang/srt/models/vila.py +0 -306
- /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/minimax_m2.py (new file)
@@ -0,0 +1,922 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Adapted from DeepSeek and Mixtral implementation
+"""Inference-only MiniMax M2 model compatible with HuggingFace weights."""
+
+import logging
+from typing import Iterable, Optional, Set, Tuple, Union
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from sglang.srt.distributed import (
+    get_moe_expert_parallel_world_size,
+    get_pp_group,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.communicator import (
+    LayerCommunicator,
+    LayerScatterModes,
+    ScatterMode,
+)
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.utils import get_moe_a2a_backend
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.server_args import get_global_server_args
+from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
+from sglang.srt.utils import (
+    BumpAllocator,
+    add_prefix,
+    get_compiler_backend,
+    is_non_idle_and_non_empty,
+    make_layers,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class MiniMaxM2RMSNormTP(nn.Module):
+    """RMSNorm with Tensor Parallel support for QK normalization."""
+
+    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
+        super().__init__()
+        self.tp_world = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+
+        # Weight parameter is sharded across TP ranks
+        self.weight = nn.Parameter(torch.ones(int(hidden_size / self.tp_world)))
+        self.weight.weight_loader = self.weight_loader
+        self.variance_epsilon = eps
+
+    @staticmethod
+    def weight_loader(
+        param: nn.Parameter,
+        loaded_weight: torch.Tensor,
+    ) -> None:
+        """Custom weight loader that handles TP sharding."""
+        tp_world = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+
+        shard_size = loaded_weight.shape[0] // tp_world
+        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+        param.data.copy_(loaded_weight[shard])
+
+    @torch.compile(dynamic=True, backend=get_compiler_backend())
+    def forward(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """Forward pass with TP-aware variance computation."""
+        assert residual is None, "RMSNormTP does not support residual connection."
+
+        orig_dtype = x.dtype
+        x = x.to(torch.float32)
+
+        # Compute variance across the full dimension (not just local shard)
+        variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
+
+        if self.tp_world > 1:
+            # All-reduce variance across TP ranks to get global variance
+            variance = tensor_model_parallel_all_reduce(variance) / self.tp_world
+
+        # Normalize and apply local weight shard
+        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        x = (x * self.weight).to(orig_dtype)
+
+        return x
+
+
+class MiniMaxM2MLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "mlp",
+    ) -> None:
+        super().__init__()
+
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
+        )
+        self.act_fn = SiluAndMul()
+        return
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class MiniMaxM2MoE(nn.Module):
+    """MiniMax MoE implementation using DeepEP for Expert Parallel support."""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        if self.tp_size > config.num_local_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.num_local_experts}."
+            )
+        self.use_routing_bias = getattr(config, "use_routing_bias", False)
+        if self.use_routing_bias:
+            self.e_score_correction_bias = nn.Parameter(
+                torch.empty(config.num_local_experts, dtype=torch.float32)
+            )
+            self.e_score_correction_bias.weight_loader = (
+                MiniMaxM2MoE.ebias_weight_loader
+            )
+        else:
+            self.e_score_correction_bias = None
+
+        self.experts = get_moe_impl_class(quant_config)(
+            num_experts=config.num_local_experts
+            + get_global_server_args().ep_num_redundant_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("experts", prefix),
+        )
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok,
+            renormalize=True,
+            scoring_func=config.scoring_func,
+            use_grouped_topk=True,  # TODO: Use "grouped top-k" flag only for hardcoded sigmoid scoring
+            num_expert_group=1,
+            topk_group=1,
+            correction_bias=self.e_score_correction_bias,
+            routed_scaling_factor=1.0,
+        )
+
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            config.num_local_experts,
+            bias=False,
+            params_dtype=torch.float32,
+            quant_config=None,
+            prefix=add_prefix("gate", prefix),
+        )
+
+        self.layer_id = layer_id
+
+        if get_moe_a2a_backend().is_deepep():
+            self.ep_size = get_moe_expert_parallel_world_size()
+            self.top_k = config.num_experts_per_tok
+
+    @staticmethod
+    def ebias_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor) -> None:
+        assert param.size() == loaded_weight.size()
+        param.data.copy_(loaded_weight.to(torch.float32))
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        if get_moe_a2a_backend().is_deepep():
+            return self.forward_deepep(hidden_states, forward_batch)
+        else:
+            return self.forward_normal(hidden_states)
+
+    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_tokens, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states.to(torch.float32))
+        topk_output = self.topk(hidden_states, router_logits)
+
+        final_hidden_states = self.experts(hidden_states, topk_output)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states.view(num_tokens, hidden_dim)
+
+    def forward_deepep(
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        if hidden_states.shape[0] > 0:
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states.to(torch.float32))
+            topk_weights, topk_idx, _ = self.topk(
+                hidden_states,
+                router_logits,
+                num_token_non_padded=forward_batch.num_token_non_padded,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
+            )
+        else:
+            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
+                hidden_states.shape[0], self.top_k
+            )
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+        )
+
+        return final_hidden_states
+
+    # TBO Operations for MiniMax MoE
+    def op_gate(self, state):
+        """Gate operation for TBO - compute router logits"""
+        if is_non_idle_and_non_empty(
+            state.forward_batch.forward_mode, state.hidden_states_mlp_input
+        ):  # router_logits: (num_tokens, num_experts)
+            state.router_logits, _ = self.gate(state.hidden_states_mlp_input)
+        else:
+            state.router_logits = None
+
+    def op_select_experts(self, state):
+        """Expert selection operation for TBO"""
+        router_logits = state.pop("router_logits")
+        hidden_states = state.hidden_states_mlp_input
+
+        if router_logits is not None:
+            with get_global_expert_distribution_recorder().with_current_layer(
+                self.layer_id
+            ):
+                state.topk_weights_local, state.topk_idx_local, _ = self.topk(
+                    hidden_states=hidden_states,
+                    router_logits=router_logits,
+                    num_token_non_padded=state.forward_batch.num_token_non_padded,
+                    expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                        layer_id=self.layer_id,
+                    ),
+                )
+        else:
+            state.topk_idx_local = torch.full(
+                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+            )
+            state.topk_weights_local = torch.empty(
+                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+            )
+
+    def op_dispatch_a(self, state):
+        """Dispatch A operation for TBO - start async dispatch"""
+        if self.ep_size > 1:
+            self.experts.deepep_dispatcher.dispatch_a(
+                hidden_states=state.pop("hidden_states_mlp_input"),
+                topk_idx=state.pop("topk_idx_local"),
+                topk_weights=state.pop("topk_weights_local"),
+                forward_batch=state.forward_batch,
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )
+
+    def op_dispatch_b(self, state):
+        """Dispatch B operation for TBO - complete async dispatch"""
+        if self.ep_size > 1:
+            with get_global_expert_distribution_recorder().with_current_layer(
+                self.layer_id
+            ):
+                state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
+                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
+                )
+
+    def op_experts(self, state):
+        """Expert computation for TBO"""
+        state.hidden_states_experts_output = self.experts.moe_impl(
+            dispatch_output=state.dispatch_output,
+        )
+
+    def op_combine_a(self, state):
+        """Combine A operation for TBO - start async combine"""
+        if self.ep_size > 1:
+            self.experts.deepep_dispatcher.combine_a(
+                hidden_states=state.pop("hidden_states_experts_output"),
+                topk_idx=state.dispatch_output.topk_idx,
+                topk_weights=state.dispatch_output.topk_weights,
+                forward_batch=state.forward_batch,
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )
+            state.pop("dispatch_output")
+
+    def op_combine_b(self, state):
+        """Combine B operation for TBO - complete async combine"""
+        if self.ep_size > 1:
+            state.hidden_states_after_combine = (
+                self.experts.deepep_dispatcher.combine_b(
+                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
+                )
+            )

+    def op_output(self, state):
+        """Output operation for TBO - final MLP output"""
+        final_hidden_states = state.pop("hidden_states_after_combine")
+        # MiniMax doesn't have shared experts like DeepSeek, so no need to add them
+        state.hidden_states_mlp_output = final_hidden_states
+
+
+class MiniMaxM2Attention(nn.Module):
+    """MiniMax Attention implementation with QK normalization and partial RoPE."""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+
+        # Get dimensions from config
+        self.total_num_heads = config.num_attention_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = config.num_key_value_heads
+
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+
+        # Use head_dim from config if available, otherwise calculate
+        self.head_dim = getattr(
+            config, "head_dim", self.hidden_size // self.total_num_heads
+        )
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+
+        # RoPE settings - support partial RoPE
+        self.rope_theta = getattr(config, "rope_theta", 10000)
+        self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        self.rotary_dim = getattr(
+            config, "rotary_dim", self.head_dim
+        )  # MiniMax uses rotary_dim=64
+
+        # QK Normalization settings
+        self.use_qk_norm = getattr(config, "use_qk_norm", False)
+        self.qk_norm_type = getattr(config, "qk_norm_type", "per_layer")
+
+        self.qkv_proj = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
+        )
+
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            self.hidden_size,
+            bias=False,
+            reduce_results=False,
+            quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
+        )
+
+        # Setup RoPE with partial rotary dimension
+        rope_scaling = getattr(config, "rope_scaling", None)
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.rotary_dim,  # Use partial rotary dimension
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+            rope_scaling=rope_scaling,
+        )
+
+        # QK Normalization layers
+        if self.use_qk_norm:
+            if self.qk_norm_type == "per_layer":
+                # Use RMSNormTP for proper tensor parallel support
+                # Use total dimensions (before TP sharding) for correct normalization
+                self.q_norm = MiniMaxM2RMSNormTP(
+                    self.total_num_heads * self.head_dim, eps=config.rms_norm_eps
+                )
+                self.k_norm = MiniMaxM2RMSNormTP(
+                    self.total_num_kv_heads * self.head_dim, eps=config.rms_norm_eps
+                )
+            else:
+                raise ValueError(f"Unsupported qk_norm_type: {self.qk_norm_type}")
+
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+
+    def forward_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        if self.use_qk_norm:
+            q = self.q_norm(q.contiguous())
+            k = self.k_norm(k.contiguous())
+        else:
+            q, k = q.contiguous(), k.contiguous()
+        q, k = self.rotary_emb(positions, q, k)
+        inner_state = q, k, v, forward_batch
+        return None, forward_batch, inner_state
+
+    def forward_core(self, intermediate_state):
+        _, _, inner_state = intermediate_state
+        attn_output = self.attn(*inner_state)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        s = self.forward_prepare(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+        return self.forward_core(s)
+
+    def op_prepare(self, state):
+        state.attn_intermediate_state = self.forward_prepare(
+            positions=state.positions,
+            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
+            forward_batch=state.forward_batch,
+        )
+
+    def op_core(self, state):
+        state.hidden_states_after_attn = self.forward_core(
+            state.pop("attn_intermediate_state")
+        )
+
+
+class MiniMaxM2DecoderLayer(nn.Module):
+    """MiniMax Decoder Layer implementation with MoE support."""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.layer_id = layer_id
+
+        # TBO support: All MiniMax layers are sparse (MoE)
+        self.is_layer_sparse = True
+
+        self.self_attn = MiniMaxM2Attention(
+            config=config,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
+        )
+
+        self.block_sparse_moe = MiniMaxM2MoE(
+            config=config,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+
+        self.input_layernorm = RMSNorm(
+            config.hidden_size, eps=getattr(config, "rms_norm_eps", 1e-6)
+        )
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=getattr(config, "rms_norm_eps", 1e-6)
+        )
+
+        is_previous_layer_sparse = True
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=self.is_layer_sparse,
+            is_previous_layer_sparse=is_previous_layer_sparse,
+        )
+
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        # Self Attention
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        # Fully Connected (MLP or MoE)
+
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )
+
+        hidden_states = self.block_sparse_moe(hidden_states, forward_batch)
+
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )
+
+        return hidden_states, residual
+
+    # TBO Operations for MiniMax Decoder Layer
+    def op_comm_prepare_attn(
+        self,
+        state,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
+        tbo_subbatch_index: Optional[int] = None,
+    ):
+        """Communication prepare for attention - TBO operation"""
+        state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
+            self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
+        )
+        state.update(
+            dict(
+                forward_batch=forward_batch,
+                positions=positions,
+                zero_allocator=zero_allocator,
+                tbo_subbatch_index=tbo_subbatch_index,
+            )
+        )
+
+    def op_comm_prepare_mlp(self, state):
+        """Communication prepare for MLP - TBO operation"""
+        state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
+            self.layer_communicator.prepare_mlp(
+                state.pop("hidden_states_after_attn"),
+                state.pop("residual_after_input_ln"),
+                state.forward_batch,
+            )
+        )
+
+    def op_mlp(self, state):
+        hidden_states = state.pop("hidden_states_mlp_input")
+        state.hidden_states_mlp_output = self.block_sparse_moe(
+            hidden_states, state.forward_batch
+        )
+
+    def op_comm_postprocess_layer(self, state):
+        """Communication postprocess for layer - TBO operation"""
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            state.pop("hidden_states_mlp_output"),
+            state.pop("residual_after_comm_pre_mlp"),
+            state.forward_batch,
+        )
+
+        output = dict(
+            positions=state.positions,
+            hidden_states=hidden_states,
+            residual=residual,
+            forward_batch=state.forward_batch,
+            zero_allocator=state.zero_allocator,
+            tbo_subbatch_index=state.tbo_subbatch_index,
+        )
+        return output
+
+
+class MiniMaxM2Model(nn.Module):
+    """MiniMax Model implementation."""
+
+    fall_back_to_pt_during_load = False
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.padding_idx = getattr(config, "pad_token_id", 0)
+        self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+
+        def layer_fn(idx, prefix: str) -> nn.Module:
+            return MiniMaxM2DecoderLayer(
+                config=config,
+                layer_id=idx,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            layer_fn,
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix=add_prefix("layers", prefix),
+        )
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.get_input_embeddings(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        if forward_batch.can_run_tbo:
+            hidden_states, residual = model_forward_maybe_tbo(
+                layers=self.layers,
+                enable_tbo=True,
+                input_data_scatter_mode=ScatterMode.model_input_output(),
+                positions=positions,
+                forward_batch=forward_batch,
+                hidden_states=hidden_states,
+                residual=residual,
+            )
+        else:
+            for i in range(self.start_layer, self.end_layer):
+                with get_global_expert_distribution_recorder().with_current_layer(i):
+                    layer = self.layers[i]
+                    hidden_states, residual = layer(
+                        positions=positions,
+                        forward_batch=forward_batch,
+                        hidden_states=hidden_states,
+                        residual=residual,
+                    )
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {"hidden_states": hidden_states, "residual": residual}
+            )
+
+        if residual is not None:
+            hidden_states, _ = self.norm(hidden_states, residual)
+        else:
+            hidden_states = self.norm(hidden_states)
+
+        return hidden_states
+
+
+class MiniMaxM2ForCausalLM(nn.Module):
+    """MiniMax M2 model for causal language modeling."""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        self.quant_config = quant_config
+
+        self.model = MiniMaxM2Model(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
+
+        if get_pp_group().is_last_rank:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=None,
+                prefix=add_prefix("lm_head", prefix),
+            )
+        else:
+            self.lm_head = PPMissingLayer()
+
+        self.logits_processor = LogitsProcessor(config)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        # _print_tensor_info(input_ids, "input_ids")
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        """Load model weights with proper mapping for MiniMax architecture."""
+
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="w1",
+            ckpt_down_proj_name="w2",
+            ckpt_up_proj_name="w3",
+            num_experts=self.config.num_local_experts,
+        )
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
+            if spec_layer is not None:
+                continue  # skip spec decode layers for main model
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if ("mlp.experts." in name) and name not in params_dict:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.num_local_experts,
+            num_groups=None,
+        )
+
+
+def get_spec_layer_idx_from_weight_name(
+    config: PretrainedConfig, weight_name: str
+) -> Optional[int]:
+    if hasattr(config, "num_mtp_modules") and (config.num_mtp_modules > 0):
+        layer_idx = config.num_hidden_layers
+        for i in range(config.num_mtp_modules):
+            if weight_name.startswith(f"model.layers.{layer_idx + i}."):
+                return layer_idx + i
+    return None
+
+
+# Entry class for model registration
+EntryClass = MiniMaxM2ForCausalLM