sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +113 -17
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +11 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +4 -3
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +71 -0
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/vision.py +13 -5
- sglang/srt/layers/communicator.py +21 -4
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +2 -7
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +77 -73
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +3 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +55 -30
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +28 -7
- sglang/srt/managers/scheduler.py +26 -12
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +24 -6
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +7 -6
- sglang/srt/model_executor/forward_batch_info.py +35 -14
- sglang/srt/model_executor/model_runner.py +19 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +72 -33
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +24 -12
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +142 -7
- sglang/srt/two_batch_overlap.py +157 -5
- sglang/srt/utils.py +38 -2
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/models/gpt_oss.py (new file)
@@ -0,0 +1,1134 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+"""Inference-only GptOss model compatible with HuggingFace weights."""
+
+import logging
+from collections.abc import Iterable
+from functools import partial
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from sglang.srt.distributed import (
+    get_moe_expert_parallel_rank,
+    get_moe_expert_parallel_world_size,
+    get_moe_tensor_parallel_rank,
+    get_moe_tensor_parallel_world_size,
+    get_pp_group,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    get_local_attention_dp_size,
+)
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.utils import DeepEPMode
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix, make_layers
+
+
+class GptOssConfig(PretrainedConfig):
+    model_type = "gpt_oss"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+
+logger = logging.getLogger(__name__)
+
+
+# Aligned with HF's implementation, using sliding window inclusive with the last token
+# SGLang assumes exclusive
+def get_attention_sliding_window_size(config):
+    return config.sliding_window - 1
+
+
+class GptOssSparseMoeBlock(nn.Module):
+    def __init__(
+        self,
+        layer_id: int,
+        config: GptOssConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.layer_id = layer_id
+        self.activation = config.hidden_act
+        self.activation_alpha = getattr(config, "hidden_act_alpha", 1.702)
+        self.swiglu_limit = config.swiglu_limit
+
+        if global_server_args_dict["enable_flashinfer_mxfp4_moe"]:
+            self.topk = None
+        else:
+            self.topk = TopK(
+                top_k=config.num_experts_per_tok,
+                renormalize=True,
+            )
+
+        self.top_k = config.num_experts_per_tok
+        experts_type = get_moe_impl_class()
+        extra_kwargs = {}
+        if experts_type.__name__ == "FusedMoE":
+            quant_config_name = (
+                quant_config.get_name() if quant_config is not None else None
+            )
+            extra_kwargs = {
+                "enable_flashinfer_cutlass_moe": global_server_args_dict[
+                    "enable_flashinfer_cutlass_moe"
+                ],
+                # for moe gate_up_proj and down_proj and their bias loading
+                "use_weight_loader_fused": quant_config_name != "mxfp4",
+            }
+        self.experts = experts_type(
+            num_experts=config.num_local_experts
+            + global_server_args_dict["ep_num_redundant_experts"],
+            top_k=config.num_experts_per_tok,
+            layer_id=layer_id,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            quant_config=quant_config,
+            activation=self.activation,
+            activation_alpha=self.activation_alpha,
+            swiglu_limit=self.swiglu_limit,
+            with_bias=True,
+            prefix=add_prefix("experts", prefix),
+            **(
+                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
+                if global_server_args_dict["moe_a2a_backend"].is_deepep()
+                else {}
+            ),
+            **extra_kwargs,
+        )
+
+        self.router = ReplicatedLinear(
+            config.hidden_size,
+            config.num_local_experts,
+            bias=True,
+            quant_config=None,
+            prefix=add_prefix("gate", prefix),
+            params_dtype=config.torch_dtype,
+        )
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+    ) -> torch.Tensor:
+        if not global_server_args_dict["moe_a2a_backend"].is_deepep():
+            return self.forward_normal(hidden_states)
+        else:
+            raise Exception("forward_deepep branch not implemented yet")
+
+    def get_moe_weights(self):
+        return [
+            x.data
+            for name, x in self.experts.named_parameters()
+            if name not in ["correction_bias"]
+        ]
+
+    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_tokens, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.router(hidden_states)
+
+        kwargs = {"hidden_states": hidden_states}
+        if self.topk is not None:
+            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+        else:
+            kwargs["topk_output"] = (self.top_k, router_logits)
+        final_hidden_states = self.experts(**kwargs)
+
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        ans = final_hidden_states.view(num_tokens, hidden_dim)
+        return ans
+
+
+class GptOssAttention(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        layer_id: int = 0,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        head_dim: Optional[int] = None,
+        rms_norm_eps: float = 1e-06,
+        attention_bias: bool = False,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        sliding_window_size: int = -1,  # if -1, normal attention, else, window attention.
+        layer_type: str = "",
+        params_dtype: torch.dtype = torch.bfloat16,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.sliding_window_size = sliding_window_size
+
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % attn_tp_size == 0
+        self.num_heads = self.total_num_heads // attn_tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= attn_tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % attn_tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert attn_tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
+        self.head_dim = head_dim or hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+        self.tp_rank = get_tensor_model_parallel_rank()
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=attention_bias,
+            params_dtype=params_dtype,
+            quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+            prefix=add_prefix("qkv_proj", prefix),
+        )
+
+        self.sinks = nn.Parameter(
+            torch.empty(self.num_heads, dtype=params_dtype), requires_grad=False
+        )
+
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=attention_bias,
+            quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+            reduce_results=False,
+            params_dtype=params_dtype,
+            prefix=add_prefix("o_proj", prefix),
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+
+        assert layer_type in {"sliding_attention", "full_attention"}
+        use_sliding_window = layer_type == "sliding_attention"
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
+            sliding_window_size=(sliding_window_size if use_sliding_window else -1),
+        )
+        self.layer_id = layer_id
+
+    def forward_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        if hidden_states.shape[0] == 0:
+            return hidden_states, forward_batch, None
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        inner_state = q, k, v, forward_batch
+        return None, forward_batch, inner_state
+
+    def forward_core(self, intermediate_state):
+        hidden_states, forward_batch, inner_state = intermediate_state
+        if inner_state is None:
+            return hidden_states
+        attn_output = self.attn(*inner_state, sinks=self.sinks.to(torch.float32))
+        output, _ = self.o_proj(attn_output)
+        return output
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        s = self.forward_prepare(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+        return self.forward_core(s)
+
+
+class GptOssDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: GptOssConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        sliding_window_size: int | None = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+        rms_norm_eps = config.rms_norm_eps
+        attention_bias = config.attention_bias
+
+        if sliding_window_size is None:
+            self.sliding_window_size = get_attention_sliding_window_size(self.config)
+        else:
+            self.sliding_window_size = sliding_window_size
+
+        self.self_attn = GptOssAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            head_dim=head_dim,
+            rms_norm_eps=rms_norm_eps,
+            attention_bias=attention_bias,
+            prefix=add_prefix("self_attn", prefix),
+            sliding_window_size=self.sliding_window_size,
+            layer_type=config.layer_types[layer_id],
+            params_dtype=config.torch_dtype,
+        )
+
+        self.layer_id = layer_id
+
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
+        self.local_dp_size = get_local_attention_dp_size()
+
+        # GptOss all layers are sparse and have no nextn now
+        self.is_layer_sparse = True
+        is_previous_layer_sparse = True
+
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=self.is_layer_sparse,
+            is_previous_layer_sparse=is_previous_layer_sparse,
+        )
+
+        if self.is_layer_sparse:
+            self.mlp = GptOssSparseMoeBlock(
+                layer_id=self.layer_id,
+                config=config,
+                quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
+            )
+        else:
+            raise NotImplementedError(
+                "Dense MLP is not implemented for GptOssDecoderLayer. "
+                "Please use GptOssSparseMoeBlock instead."
+            )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+
+        if hidden_states.shape[0] != 0:
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
+
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )
+
+        hidden_states = self.mlp(hidden_states, forward_batch)
+
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )
+
+        return hidden_states, residual
+
+
+class GptOssModel(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        decoder_layer_type: type[nn.Module] = GptOssDecoderLayer,
+    ) -> None:
+        super().__init__()
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()
+
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                enable_tp=not global_server_args_dict["enable_dp_attention"],
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        # Use the provided decoder layer type or default to GptOssDecoderLayer
+        decoder_layer_type = decoder_layer_type or GptOssDecoderLayer
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: decoder_layer_type(
+                layer_id=idx,
+                config=config,
+                quant_config=quant_config,
+                prefix=prefix,
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix=add_prefix("layers", prefix),
+        )
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
+
+        self.layers_to_capture = []
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        aux_hidden_states = []
+        for i in range(self.start_layer, self.end_layer):
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                if i in self.layers_to_capture:
+                    aux_hidden_states.append(hidden_states + residual)
+                layer = self.layers[i]
+                hidden_states, residual = layer(
+                    positions, hidden_states, forward_batch, residual
+                )
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            if hidden_states.shape[0] != 0:
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
+
+
+class GptOssForCausalLM(nn.Module):
+    fall_back_to_pt_during_load = False
+
+    def __init__(
+        self,
+        config: GptOssConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.pp_group = get_pp_group()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = GptOssModel(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            # quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)
+        self.capture_aux_hidden_states = False
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids,
+                hidden_states,
+                self.lm_head,
+                forward_batch,
+                aux_hidden_states,
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
+
+    def _get_default_weight_mapping(self):
+        """Generate default weight name mapping for GptOss safetensors."""
+        weight_mapping = {}
+
+        # Map router weights to gate
+        weight_mapping["embedding.weight"] = "model.embed_tokens.weight"
+        weight_mapping["unembedding.weight"] = "lm_head.weight"
+        weight_mapping["norm.scale"] = "model.norm.weight"
+        for layer_id in range(self.config.num_hidden_layers):
+            weight_mapping[f"block.{layer_id}.attn.q_proj.weight"] = (
+                f"model.layers.{layer_id}.self_attn.q_proj.weight"
+            )
+            weight_mapping[f"block.{layer_id}.attn.q_proj.bias"] = (
+                f"model.layers.{layer_id}.self_attn.q_proj.bias"
+            )
+
+            weight_mapping[f"block.{layer_id}.attn.k_proj.weight"] = (
+                f"model.layers.{layer_id}.self_attn.k_proj.weight"
+            )
+            weight_mapping[f"block.{layer_id}.attn.k_proj.bias"] = (
+                f"model.layers.{layer_id}.self_attn.k_proj.bias"
+            )
+
+            weight_mapping[f"block.{layer_id}.attn.v_proj.weight"] = (
+                f"model.layers.{layer_id}.self_attn.v_proj.weight"
+            )
+            weight_mapping[f"block.{layer_id}.attn.v_proj.bias"] = (
+                f"model.layers.{layer_id}.self_attn.v_proj.bias"
+            )
+
+            weight_mapping[f"block.{layer_id}.attn.out.weight"] = (
+                f"model.layers.{layer_id}.self_attn.o_proj.weight"
+            )
+            weight_mapping[f"block.{layer_id}.attn.out.bias"] = (
+                f"model.layers.{layer_id}.self_attn.o_proj.bias"
+            )
+            weight_mapping[f"block.{layer_id}.attn.sinks"] = (
+                f"model.layers.{layer_id}.self_attn.sinks"
+            )
+            weight_mapping[f"block.{layer_id}.attn.norm.scale"] = (
+                f"model.layers.{layer_id}.input_layernorm.weight"
+            )
+
+            weight_mapping[f"block.{layer_id}.mlp.gate.weight"] = (
+                f"model.layers.{layer_id}.mlp.router.weight"
+            )
+            weight_mapping[f"block.{layer_id}.mlp.gate.bias"] = (
+                f"model.layers.{layer_id}.mlp.router.bias"
+            )
+            weight_mapping[f"block.{layer_id}.mlp.norm.scale"] = (
+                f"model.layers.{layer_id}.post_attention_layernorm.weight"
+            )
+            weight_mapping[f"block.{layer_id}.mlp.experts.gate_up_proj"] = (
+                f"model.layers.{layer_id}.mlp.experts.gate_up_proj"
+            )
+            weight_mapping[f"block.{layer_id}.mlp.gate_up_proj_bias"] = (
+                f"model.layers.{layer_id}.mlp.experts.gate_up_proj_bias"
+            )
+            weight_mapping[f"block.{layer_id}.mlp.down_proj"] = (
+                f"model.layers.{layer_id}.mlp.experts.mlp2_weight"
+            )
+            weight_mapping[f"block.{layer_id}.mlp.down_proj_bias"] = (
+                f"model.layers.{layer_id}.mlp.experts.mlp2_bias"
+            )
+
+        return weight_mapping
+
+    # TODO beautify code
+    def load_weights(
+        self,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+        is_nextn: bool = False,
+        weight_name_mapping: dict = None,
+    ):
+        quant_config_name = (
+            self.quant_config.get_name() if self.quant_config is not None else None
+        )
+        if quant_config_name != "mxfp4":
+            self._load_normal_weights(
+                weights, is_nextn=is_nextn, weight_name_mapping=weight_name_mapping
+            )
+        else:
+            self._load_weights_mxfp4(
+                weights, is_nextn=is_nextn, weight_name_mapping=weight_name_mapping
+            )
+
+    def _load_weights_mxfp4(self, weights, is_nextn, weight_name_mapping):
+        mxfp4_weights = []
+        normal_weights = []
+
+        for name, weight in weights:
+            if (
+                ".experts" in name
+                and self.quant_config is not None
+                and self.quant_config.get_name() == "mxfp4"
+            ):
+                mxfp4_weights.append((name, weight))
+            else:
+                normal_weights.append((name, weight))
+
+        mxfp4_loaded_params = self._load_mxfp4_experts_weights(mxfp4_weights)
+        self._load_normal_weights(
+            normal_weights,
+            is_nextn=is_nextn,
+            weight_name_mapping=weight_name_mapping,
+            other_loaded_param_names=mxfp4_loaded_params,
+        )
+
+    def _load_mxfp4_experts_weights(self, weights):
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        mxfp4_block = 32
+
+        moe_tp_rank = get_moe_tensor_parallel_rank()
+        moe_tp_size = get_moe_tensor_parallel_world_size()
+        moe_ep_rank = get_moe_expert_parallel_rank()
+        moe_ep_size = get_moe_expert_parallel_world_size()
+
+        intermediate_size = self.config.intermediate_size
+        intermediate_size_block = intermediate_size // mxfp4_block
+        per_rank_intermediate_size_block = intermediate_size_block // moe_tp_size
+        per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block
+
+        # Calculate common slicing bounds for current rank
+        assert self.config.num_local_experts % moe_ep_size == 0
+        moe_num_global_experts = self.config.num_local_experts
+        moe_num_local_experts = self.config.num_local_experts // moe_ep_size
+        moe_tp_rank_start = moe_tp_rank * per_rank_intermediate_size
+        moe_tp_rank_end = min(
+            (moe_tp_rank + 1) * per_rank_intermediate_size, intermediate_size
+        )
+        moe_ep_rank_start = moe_ep_rank * moe_num_local_experts
+        moe_ep_rank_end = (moe_ep_rank + 1) * moe_num_local_experts
+
+        for name, weight in weights:
+            weight = weight.cuda()
+
+            if "gate_up_proj_blocks" in name:
+                # Handle MLP gate and up projection weights
+                new_name = name.replace("gate_up_proj_blocks", "w13_weight")
+
+                # flat weight from (E, 2 * N, block_size, entry_per_block)
+                # to (E, 2 * N, -1), shouldn't trigger copy for contiguous
+                weight = weight.view(
+                    moe_num_global_experts, 2 * intermediate_size, -1
+                ).contiguous()
+
+                narrow_weight = weight[
+                    moe_ep_rank_start:moe_ep_rank_end,
+                    2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
+                    ...,
+                ]
+
+                param = params_dict[new_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(
+                    param,
+                    narrow_weight,
+                    weight_name=new_name,
+                    shard_id=None,
+                    expert_id=None,
+                )
+                loaded_params.add(new_name)
+
+            elif "down_proj_blocks" in name:
+                # Handle MLP down projection weights
+                new_name = name.replace("down_proj_blocks", "w2_weight")
+                # same flatten here, but since 2 mx4 value are packed in 1
+                # uint8, divide by 2
+                weight = weight.view(
+                    moe_num_global_experts, -1, intermediate_size // 2
+                ).contiguous()
+                narrow_weight = weight[
+                    moe_ep_rank_start:moe_ep_rank_end,
+                    ...,
+                    moe_tp_rank_start // 2 : moe_tp_rank_end // 2,
+                ]
+
+                param = params_dict[new_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(
+                    param,
+                    narrow_weight,
+                    weight_name=new_name,
+                    shard_id=None,
+                    expert_id=None,
+                )
+                loaded_params.add(new_name)
+
+            elif "gate_up_proj_scales" in name:
+                # Handle MLP gate and up projection weights scale
+                new_name = name.replace("gate_up_proj_scales", "w13_weight_scale")
+                narrow_weight = weight[
+                    moe_ep_rank_start:moe_ep_rank_end,
+                    2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
+                    ...,
+                ]
+
+                param = params_dict[new_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(
+                    param,
+                    narrow_weight,
+                    weight_name=new_name,
+                    shard_id=None,
+                    expert_id=None,
+                )
+                loaded_params.add(new_name)
+
+            elif "down_proj_scales" in name:
+                # Handle MLP down projection weights
+                new_name = name.replace("down_proj_scales", "w2_weight_scale")
+                narrow_weight = weight[
+                    moe_ep_rank_start:moe_ep_rank_end,
+                    ...,
+                    moe_tp_rank_start // mxfp4_block : moe_tp_rank_end // mxfp4_block,
+                ]
+
+                param = params_dict[new_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(
+                    param,
+                    narrow_weight,
+                    weight_name=new_name,
+                    shard_id=None,
+                    expert_id=None,
+                )
+                loaded_params.add(new_name)
+            elif "gate_up_proj_bias" in name:
+                # Handle MLP gate and up projection biases
+                new_name = name.replace("gate_up_proj_bias", "w13_weight_bias")
+
+                narrow_weight = weight[
+                    moe_ep_rank_start:moe_ep_rank_end,
+                    2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
+                ]
+
+                param = params_dict[new_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(
+                    param,
+                    narrow_weight,
+                    weight_name=new_name,
+                    shard_id=None,
+                    expert_id=None,
+                )
+                loaded_params.add(new_name)
+
+            elif "down_proj_bias" in name:
+                narrow_weight = weight[moe_ep_rank_start:moe_ep_rank_end, ...]
+                if moe_tp_rank != 0:
+                    narrow_weight = torch.zeros_like(narrow_weight)
+
+                # Handle MLP down projection bias
+                new_name = name.replace("down_proj_bias", "w2_weight_bias")
+                param = params_dict[new_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(
+                    param,
+                    narrow_weight,
+                    weight_name=new_name,
+                    shard_id=None,
+                    expert_id=None,
+                )
+                loaded_params.add(new_name)
+
+        return loaded_params
+
+    def _load_normal_weights(
+        self,
+        weights,
+        is_nextn: bool,
+        weight_name_mapping: dict,
+        other_loaded_param_names=[],
+    ):
+        tp_rank = get_tensor_model_parallel_rank()
+        if is_nextn:
+            logging.warning(
+                "Loading weights for nextn is currently not supported in GptOssForCausalLM. "
+            )
+            return
+        weights = _canonicalize_weights(self.config, weights)
+        weights = sorted(weights, key=lambda x: x[0])  # Sort by name for consistency
+
+        new_weights = []
+        for name, p in weights:
+            if "qkv.weight" in name:
+                q_proj, k_proj, v_proj = p.split(
+                    [
+                        self.config.num_attention_heads * self.config.head_dim,
+                        self.config.num_key_value_heads * self.config.head_dim,
+                        self.config.num_key_value_heads * self.config.head_dim,
+                    ],
+                    dim=0,
+                )
+                new_weights.append(
+                    (f"{name.replace('qkv.weight', 'q_proj.weight')}", q_proj)
+                )
+                new_weights.append(
+                    (f"{name.replace('qkv.weight', 'k_proj.weight')}", k_proj)
+                )
+                new_weights.append(
+                    (f"{name.replace('qkv.weight', 'v_proj.weight')}", v_proj)
+                )
+            elif "qkv.bias" in name:
+                q_bias, k_bias, v_bias = p.split(
+                    [
+                        self.config.num_attention_heads * self.config.head_dim,
+                        self.config.num_key_value_heads * self.config.head_dim,
+                        self.config.num_key_value_heads * self.config.head_dim,
+                    ],
+                    dim=0,
+                )
+                new_weights.append(
+                    (f"{name.replace('qkv.bias', 'q_proj.bias')}", q_bias)
+                )
+                new_weights.append(
+                    (f"{name.replace('qkv.bias', 'k_proj.bias')}", k_bias)
+                )
+                new_weights.append(
+                    (f"{name.replace('qkv.bias', 'v_proj.bias')}", v_bias)
+                )
+            else:
+                new_weights.append((name, p))
+        weights = new_weights
+
+        # Use provided weight name mapping if available, otherwise use default
+        if weight_name_mapping is None:
+            weight_name_mapping = self._get_default_weight_mapping()
+        else:
+            # Merge with default mapping
+            default_mapping = self._get_default_weight_mapping()
+            default_mapping.update(weight_name_mapping)
+            weight_name_mapping = default_mapping
+
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused(
+            ckpt_gate_up_proj_name="gate_up_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
+            ckpt_down_proj_bias_name="down_proj_bias",
+        )
+
+        params_dict = dict(self.named_parameters())
+        params_checker = {k: False for k, v in params_dict.items()}
+
+        for other_loaded_param_name in other_loaded_param_names:
+            params_checker[other_loaded_param_name] = True
+
+        for name, loaded_weight in weights:
+            loaded_weight = _WeightCreator.maybe_materialize(loaded_weight)
+
+            # Apply weight name mapping if provided
+            if weight_name_mapping and name in weight_name_mapping:
+                name = weight_name_mapping[name]
+
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                if "mlp.experts" in name:
+                    continue
+
+                name = name.replace(weight_name, param_name)
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                params_checker[name] = True
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    if name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    if "bias" not in name:
+                        loaded_weight = loaded_weight.transpose(-2, -1)
+                    if "w2_weight_bias" in name and get_moe_tensor_parallel_rank() != 0:
+                        loaded_weight = loaded_weight.zero_()
+
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                    )
+                    params_checker[name] = True
+                    break
+                else:
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    if name not in params_dict:
+                        continue
+                    if name in params_dict.keys():
+                        param = params_dict[name]
+                        if "sinks" in name:
+                            start = tp_rank * param.numel()
+                            param.data.copy_(
+                                loaded_weight[start : start + param.numel()]
+                            )
+                        else:
+                            weight_loader = getattr(
+                                param, "weight_loader", default_weight_loader
+                            )
+                            weight_loader(param, loaded_weight)
+                        params_checker[name] = True
+                    else:
+                        logger.warning(f"Parameter {name} not found in params_dict")
+
+        not_loaded_params = [k for k, v in params_checker.items() if not v]
+        if tp_rank == 0:
+            if len(not_loaded_params) > 0:
+                raise Exception(f"Not all parameters loaded: {not_loaded_params}")
+            else:
+                logging.info("All parameters loaded successfully.")
+
+        self.routed_experts_weights_of_layer = {
+            layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
+            for layer_id in range(self.start_layer, self.end_layer)
+            if isinstance(self.model.layers[layer_id].mlp, GptOssSparseMoeBlock)
+        }
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        if layer_ids is None:
+            self.capture_aux_hidden_states = True
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
+        else:
+            self.capture_aux_hidden_states = True
+            # we plus 1 here because in sglang, for the ith layer, it takes the output
+            # of the (i-1)th layer as aux hidden state
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.num_local_experts,
+            num_groups=None,
+        )
+
+    def get_attention_sliding_window_size(self):
+        return get_attention_sliding_window_size(self.config)
+
+
+def _canonicalize_weights(config, weights_in: Iterable[Tuple[str, torch.Tensor]]):
+    weights_out_dict = dict(weights_in)
+
+    for layer_id in range(config.num_hidden_layers):
+        for name_chunk in ["mlp1_weight", "mlp2_weight"]:
+            name_prefix = f"block.{layer_id}.mlp.{name_chunk}"
+            w_blocks = weights_out_dict.pop(f"{name_prefix}.blocks", None)
+            w_scales = weights_out_dict.pop(f"{name_prefix}.scales", None)
+            if w_blocks is not None:
+                weights_out_dict[name_prefix] = _WeightCreator(
+                    partial(
+                        _dequant_mlp_weight,
+                        debug_name=name_prefix,
+                        w_blocks=w_blocks,
+                        w_scales=w_scales,
+                    )
+                )
+
+    return list(weights_out_dict.items())
+
+
+def _dequant_mlp_weight(debug_name, w_blocks, w_scales):
+    if get_tensor_model_parallel_rank() == 0:
+        logger.info(f"Dequantize {debug_name} start")
+
+    original_device = w_blocks.device
+
+    w_blocks = w_blocks.cuda()
+    w_scales = w_scales.cuda()
+
+    w_bf16 = dequant_mxfp4(w_block=w_blocks, w_scale=w_scales, out_dtype=torch.bfloat16)
+    w_bf16 = w_bf16.transpose(-2, -1).contiguous()
+
+    if get_tensor_model_parallel_rank() == 0:
+        logger.info(
+            f"Dequantize {debug_name} end {w_blocks.shape=} {w_scales.shape=} {w_bf16.shape=}"
+        )
+
+    return w_bf16.to(original_device)
+
+
+class _WeightCreator:
+    def __init__(self, fn):
+        self._fn = fn
+
+    @staticmethod
+    def maybe_materialize(obj):
+        if isinstance(obj, _WeightCreator):
+            output = obj._fn()
+            obj._fn = None
+            return output
+
+        return obj
+
+
+EntryClass = GptOssForCausalLM