sglang 0.4.9.post5__py3-none-any.whl → 0.4.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/configs/step3_vl.py +172 -0
- sglang/srt/conversation.py +23 -0
- sglang/srt/disaggregation/decode.py +2 -8
- sglang/srt/disaggregation/prefill.py +2 -6
- sglang/srt/distributed/parallel_state.py +86 -1
- sglang/srt/entrypoints/engine.py +14 -18
- sglang/srt/entrypoints/http_server.py +23 -3
- sglang/srt/entrypoints/openai/protocol.py +3 -1
- sglang/srt/entrypoints/openai/serving_base.py +5 -2
- sglang/srt/entrypoints/openai/serving_chat.py +2 -21
- sglang/srt/eplb/expert_distribution.py +5 -0
- sglang/srt/eplb/expert_location.py +17 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -0
- sglang/srt/eplb/expert_location_updater.py +2 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/step3_detector.py +436 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/jinja_template_utils.py +4 -1
- sglang/srt/layers/moe/cutlass_moe.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +98 -603
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
- sglang/srt/layers/moe/fused_moe_triton/layer.py +97 -38
- sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
- sglang/srt/layers/moe/topk.py +6 -2
- sglang/srt/layers/quantization/fp8.py +0 -18
- sglang/srt/layers/quantization/modelopt_quant.py +2 -0
- sglang/srt/layers/quantization/unquant.py +0 -8
- sglang/srt/layers/quantization/w4afp8.py +1 -0
- sglang/srt/managers/cache_controller.py +143 -45
- sglang/srt/managers/data_parallel_controller.py +6 -0
- sglang/srt/managers/io_struct.py +12 -2
- sglang/srt/managers/scheduler.py +116 -669
- sglang/srt/managers/scheduler_input_blocker.py +106 -0
- sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
- sglang/srt/managers/template_manager.py +62 -19
- sglang/srt/managers/tokenizer_manager.py +166 -83
- sglang/srt/managers/tp_worker.py +9 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
- sglang/srt/mem_cache/hicache_storage.py +45 -11
- sglang/srt/mem_cache/hiradix_cache.py +15 -4
- sglang/srt/mem_cache/memory_pool_host.py +73 -1
- sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
- sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +177 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
- sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
- sglang/srt/model_executor/model_runner.py +20 -13
- sglang/srt/models/arcee.py +532 -0
- sglang/srt/models/deepseek_v2.py +15 -56
- sglang/srt/models/glm4_moe.py +3 -1
- sglang/srt/models/granitemoe.py +3 -0
- sglang/srt/models/grok.py +3 -0
- sglang/srt/models/hunyuan.py +1 -0
- sglang/srt/models/llama4.py +3 -0
- sglang/srt/models/mixtral.py +3 -0
- sglang/srt/models/olmoe.py +3 -0
- sglang/srt/models/phimoe.py +1 -0
- sglang/srt/models/qwen3_moe.py +12 -69
- sglang/srt/models/step3_vl.py +994 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -16
- sglang/srt/multimodal/processors/step3_vl.py +515 -0
- sglang/srt/poll_based_barrier.py +31 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +18 -13
- sglang/srt/speculative/eagle_worker.py +2 -0
- sglang/srt/two_batch_overlap.py +8 -3
- sglang/test/test_utils.py +53 -0
- sglang/utils.py +0 -11
- sglang/version.py +1 -1
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/METADATA +4 -4
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/RECORD +84 -64
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/top_level.txt +0 -0
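The version bump recorded in sglang/version.py can be sanity-checked after installing the new wheel; a minimal sketch (assuming a standard `pip install "sglang==0.4.10"` and that `__version__` is re-exported from `sglang/version.py` at package level, as in previous releases):

    # Quick post-upgrade check of the wheel's reported version.
    import sglang

    assert sglang.__version__ == "0.4.10", sglang.__version__
    print(sglang.__version__)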
sglang/srt/models/step3_vl.py (new file)
@@ -0,0 +1,994 @@
import logging
import math
from collections.abc import Iterable
from math import sqrt
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, TypedDict, Union

import torch
from torch import nn
from torch.nn import LayerNorm
from torch.nn import functional as F
from transformers import PretrainedConfig
from transformers.activations import ACT2FN

from sglang.srt.configs.step3_vl import (
    Step3TextConfig,
    Step3VisionEncoderConfig,
    Step3VLConfig,
)
from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.managers.mm_utils import (
    MultiModalityDataPaddingPatternMultimodalTokens,
    general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import (
    Modality,
    MultimodalDataItem,
    MultimodalInputs,
    global_server_args_dict,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, log_info_on_rank0, make_layers

logger = logging.getLogger(__name__)


"""
Text Model
"""


class Step3TextMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("down_proj", prefix),
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class Step3TextMoEMLP(nn.Module):
    # Native
    def __init__(
        self,
        layer_id: int,
        config: Step3TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.layer_id = layer_id
        if self.tp_size > config.moe_num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.moe_num_experts}."
            )

        self.topk = TopK(
            top_k=config.moe_top_k,
            renormalize=config.norm_expert_weight,
            use_grouped_topk=False,
        )

        self.experts = get_moe_impl_class()(
            num_experts=config.moe_num_experts,
            top_k=config.moe_top_k,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            layer_id=layer_id,
            quant_config=quant_config,
            prefix=add_prefix("experts", prefix),
        )

        self.gate = ReplicatedLinear(
            config.hidden_size,
            output_size=config.moe_num_experts,
            bias=False,
            quant_config=None,
            prefix=add_prefix("gate", prefix),
        )

        if global_server_args_dict["enable_deepep_moe"]:
            raise NotImplementedError("DeepEP MoE is not supported yet in Step3 model.")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        router_logits, _ = self.gate(hidden_states)
        topk_output = self.topk(hidden_states, router_logits)
        final_hidden_states = self.experts(
            hidden_states=hidden_states, topk_output=topk_output
        )

        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states.view(num_tokens, hidden_dim)


class Step3TextAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        share_q_dim: int,
        layer_id: int = 0,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        rms_norm_eps=None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size

        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        self.all_tp_rank = get_tensor_model_parallel_rank()
        self.total_num_heads = num_heads
        self.attn_tp_rank = attn_tp_rank
        self.layer_id = layer_id
        assert self.total_num_heads % attn_tp_size == 0
        self.num_heads = self.total_num_heads // attn_tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= attn_tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % attn_tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert attn_tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
        self.head_dim = head_dim
        self.q_size = share_q_dim if share_q_dim else head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = MergedColumnParallelLinear(
            hidden_size,
            [self.q_size, self.kv_size, self.kv_size],
            bias=False,
            quant_config=quant_config,
            tp_rank=0,  # In fact, we need a MergedReplicatedLinear
            tp_size=1,
            prefix=add_prefix("qkv_proj", prefix),
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
            reduce_results=False,
            prefix=add_prefix("o_proj", prefix),
        )

        self.inter_norm = RMSNorm(self.q_size, eps=rms_norm_eps)

        self.wq = ColumnParallelLinear(
            self.q_size,
            self.head_dim * self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
            prefix=add_prefix("wq", prefix),
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q = self.inter_norm(q.contiguous())
        q, _ = self.wq(q)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, forward_batch)
        output, _ = self.o_proj(attn_output)
        return output


class Step3TextDecoderLayer(nn.Module):
    def __init__(
        self,
        config: Step3TextConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        # TODO: support shared experts fusion
        # self.n_shared_experts = 1
        # self.num_fused_shared_experts = (
        #     0
        #     if global_server_args_dict["disable_shared_experts_fusion"]
        #     else self.n_shared_experts
        # )
        self.num_fused_shared_experts = 0
        rms_norm_eps = config.rms_norm_eps
        self.self_attn = Step3TextAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=1,
            head_dim=head_dim,
            share_q_dim=config.share_q_dim,
            layer_id=layer_id,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            rms_norm_eps=rms_norm_eps,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
        )

        moe_layers_enum = getattr(config, "moe_layers_enum", None)
        if moe_layers_enum is not None:
            moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")]
        else:
            # Default to 1dense.
            moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]

        self.use_moe = False

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.layer_id = layer_id
        self.is_layer_sparse = True if layer_id in moe_layers_idx else False
        self.is_previous_layer_sparse = (
            True if layer_id - 1 in moe_layers_idx else False
        )

        self.layer_scatter_modes = LayerScatterModes.init_new(
            layer_id=layer_id,
            num_layers=config.num_hidden_layers,
            is_layer_sparse=self.is_layer_sparse,
            is_previous_layer_sparse=self.is_previous_layer_sparse,
        )

        if not self.is_layer_sparse:
            self.mlp = Step3TextMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act="silu",
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
            )
        else:
            self.use_moe = True
            if self.num_fused_shared_experts == 0:
                self.moe = Step3TextMoEMLP(
                    layer_id=layer_id,
                    config=config,
                    quant_config=quant_config,
                    prefix=add_prefix("mlp", prefix),
                )
                self.share_expert = Step3TextMLP(
                    hidden_size=config.hidden_size,
                    intermediate_size=config.share_expert_dim,
                    hidden_act="silu",
                    quant_config=quant_config,
                    prefix=add_prefix("share_expert", prefix),
                )
            else:
                self.moe = Step3TextMoEMLP(
                    layer_id=layer_id,
                    config=config,
                    quant_config=quant_config,
                    prefix=add_prefix("mlp", prefix),
                )

        self.layer_communicator = LayerCommunicator(
            layer_scatter_modes=self.layer_scatter_modes,
            input_layernorm=self.input_layernorm,
            post_attention_layernorm=self.post_attention_layernorm,
        )

    def moe_mlp_forward(self, hidden_states):
        if not self.num_fused_shared_experts:
            h = hidden_states.clone()
            hidden_states = self.moe(hidden_states)
            hidden_states += self.share_expert(h)
        else:
            hidden_states = self.moe(hidden_states)
        return hidden_states

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        hidden_states, residual = self.layer_communicator.prepare_attn(
            hidden_states, residual, forward_batch
        )

        if hidden_states.shape[0] != 0:
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
            )

        hidden_states, residual = self.layer_communicator.prepare_mlp(
            hidden_states, residual, forward_batch
        )
        if self.use_moe:
            hidden_states = self.moe_mlp_forward(hidden_states)
        else:
            hidden_states = self.mlp(hidden_states)

        hidden_states, residual = self.layer_communicator.postprocess_layer(
            hidden_states, residual, forward_batch
        )

        return hidden_states, residual


class Step3TextModel(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            enable_tp=not global_server_args_dict["enable_dp_attention"],
            prefix=add_prefix("embed_tokens", prefix),
        )

        self.layers = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: Step3TextDecoderLayer(
                layer_id=idx,
                config=config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=add_prefix("layers", prefix),
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self):
        return self.embed_tokens

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions, hidden_states, forward_batch, residual
            )

        if hidden_states.shape[0] != 0:
            if residual is None:
                hidden_states = self.norm(hidden_states)
            else:
                hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


"""
Vision Model
"""


def get_abs_pos(abs_pos, tgt_size):
    dim = abs_pos.size(-1)
    abs_pos_new = abs_pos.squeeze(0)
    cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:]

    src_size = int(math.sqrt(abs_pos_new.shape[0] - 1))
    tgt_size = int(math.sqrt(tgt_size))
    dtype = abs_pos.dtype

    if src_size != tgt_size:
        old_pos_embed = (
            old_pos_embed.view(1, src_size, src_size, dim)
            .permute(0, 3, 1, 2)
            .contiguous()
        )
        old_pos_embed = old_pos_embed.to(torch.float32)
        new_pos_embed = F.interpolate(
            old_pos_embed,
            size=(tgt_size, tgt_size),
            mode="bicubic",
            antialias=True,
            align_corners=False,
        ).to(dtype)
        new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
        new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim)
        vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0)
        vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, dim)
        return vision_pos_embed
    else:
        return abs_pos


class Step3VisionMLP(nn.Module):
    def __init__(
        self,
        dim: int,
        intermediate_size: int,
        bias: bool = True,
        hidden_act="quick_gelu",
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.fc1 = ColumnParallelLinear(
            dim,
            intermediate_size,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("gate_proj", prefix),
        )
        self.act = ACT2FN[hidden_act]  # quick_gelu
        self.fc2 = RowParallelLinear(
            intermediate_size,
            dim,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("down_proj", prefix),
        )

    def forward(self, hidden_states) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class Step3VisionAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 16,
        qkv_backend="fa3",
        quant_config=None,
        prefix: str = "",
    ) -> None:

        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.out_proj = RowParallelLinear(
            dim,
            dim,
            bias=True,
            quant_config=quant_config,
            prefix=add_prefix("out_proj", prefix),
        )
        self.scale = self.head_dim**-0.5

        self.attn = VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            use_qkv_parallel=True,
            rotary_embed="normal",
            proj_bias=True,
            qkv_backend=qkv_backend,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        attn_output = self.attn(hidden_states)
        return attn_output


class Step3VisionEmbeddings(nn.Module):

    def __init__(self, config: Step3VisionEncoderConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(1, self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=True,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.pad_tp_size = 4  # hard code for padding
        # To load the pretrained weights, we still use P+1 as the seqlen
        self.position_embedding = torch.nn.Embedding(
            self.num_patches + 1, self.embed_dim
        )
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_patches + 1).expand((1, -1)),
            persistent=False,
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        patch_embeds = self.patch_embedding(
            pixel_values
        )  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        # pad
        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        embeddings = embeddings + get_abs_pos(
            self.position_embedding(self.position_ids), patch_embeds.size(1)
        )
        embeddings = torch.cat(
            [
                embeddings[:, 0, :].unsqueeze(1).repeat(1, self.pad_tp_size - 1, 1),
                embeddings,
            ],
            dim=1,
        )
        return embeddings


class Step3VisionEncoderLayer(nn.Module):
    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = LayerNorm(self.embed_dim, eps=1e-6)
        self.layer_norm2 = LayerNorm(self.embed_dim, eps=1e-6)

        self.self_attn = Step3VisionAttention(
            self.embed_dim, num_heads=config.num_attention_heads
        )
        self.mlp = Step3VisionMLP(
            dim=self.embed_dim,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )

    def forward(self, hidden_states) -> torch.Tensor:
        hidden_states = hidden_states + self.layer_norm1(self.self_attn(hidden_states))
        hidden_states = hidden_states + self.layer_norm2(self.mlp(hidden_states))
        return hidden_states


class Step3VisionTransformer(nn.Module):
    def __init__(self, config: Step3VisionEncoderConfig):
        super().__init__()
        self.config = config
        self.image_size = config.image_size
        self.embeddings = Step3VisionEmbeddings(config)
        self.transformer = Step3VisionEncoder(config)

    @property
    def dtype(self) -> torch.dtype:
        return self.embeddings.patch_embedding.weight.dtype

    def forward(
        self,
        pixel_values: torch.Tensor,
    ):
        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.transformer(inputs_embeds=hidden_states)
        return hidden_states


class Step3VisionEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`Step3VisionEncoderLayer`].

    Args:
        config: StepVisionEncoderConfig
    """

    def __init__(self, config: Step3VisionEncoderConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [Step3VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        inputs_embeds,
    ) -> torch.Tensor:

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states,
            )

        return hidden_states


class Step3VLForConditionalGeneration(nn.Module):

    def __init__(
        self,
        config: Step3VLConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = Step3TextModel(
            config.text_config, quant_config, prefix=add_prefix("model", prefix)
        )

        self.vision_model = Step3VisionTransformer(config.vision_config)

        self.vit_downsampler = nn.Conv2d(
            config.vision_config.hidden_size,
            config.vision_config.output_hidden_size,
            kernel_size=2,
            stride=config.understand_projector_stride,
        )
        self.vit_downsampler2 = nn.Conv2d(
            config.vision_config.output_hidden_size,
            config.vision_config.output_hidden_size * 2,
            kernel_size=3,
            stride=2,
            padding=1,
        )
        self.vit_large_projector = nn.Linear(
            config.vision_config.output_hidden_size * 2,
            config.hidden_size,
            bias=config.projector_bias,
        )

        # TODO: support shared experts fusion
        # self.n_shared_experts = 1
        # self.num_fused_shared_experts = (
        #     0
        #     if global_server_args_dict["disable_shared_experts_fusion"]
        #     else self.n_shared_experts
        # )
        self.num_fused_shared_experts = 0
        self.config.tie_word_embeddings = False
        if getattr(self.config, "tie_word_embeddings", False):
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.text_config.vocab_size,
                config.text_config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )
        self.logits_processor = LogitsProcessor(config.text_config)

    def _get_vision_model_output(self, input_tensor: torch.Tensor) -> torch.Tensor:
        return self.vision_model(input_tensor)[:, 4:]

    def _flatten_embeddings(self, embeddings) -> torch.Tensor:

        if isinstance(embeddings, torch.Tensor):
            # Flatten all but the last dimension.
            return embeddings.flatten(0, -2)

        return torch.cat(tuple(self._flatten_embeddings(t) for t in embeddings))

    def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
        B, P = image_features.shape[:2]
        HW = int(sqrt(P))
        image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
        image_features = self.vit_downsampler(image_features)
        image_features = self.vit_downsampler2(image_features)
        n_dim = image_features.size(1)
        image_features = image_features.view(B, n_dim, -1).permute(0, 2, 1)
        image_features = self.vit_large_projector(image_features)
        return image_features

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        assert len(items) == 1  # We only have images.

        item = items[0]
        pixel_values = item.feature.type(self.vision_model.dtype)
        num_patches = item.model_specific_data.get("num_patches")
        patch_pixel_values = item.model_specific_data.get("patch_pixel_values", None)
        if patch_pixel_values is not None:
            patch_pixel_values = patch_pixel_values.type(self.vision_model.dtype)

        if patch_pixel_values is not None:
            patch_pixel_values = patch_pixel_values.to("cuda")

        image_features = self._get_vision_model_output(pixel_values)
        patch_image_features = (
            self._get_vision_model_output(patch_pixel_values)
            if patch_pixel_values is not None
            else None
        )

        image_features = self._process_image_features(image_features)
        patch_image_features = (
            self._process_image_features(patch_image_features)
            if patch_image_features is not None
            else None
        )

        merged_image_features = []
        cur_patch_idx = 0
        for i, num_patch in enumerate(num_patches):
            cur_feature = []
            if num_patch > 0:
                patch_slice = patch_image_features[
                    cur_patch_idx : cur_patch_idx + num_patch
                ]
                cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
            cur_feature.append(image_features[i].view(-1, image_features.shape[-1]))
            cur_patch_idx += num_patch
            merged_image_features.append(
                torch.cat(cur_feature) if len(cur_feature) > 1 else cur_feature[0]
            )
        return self._flatten_embeddings(merged_image_features)

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.model,
            data_embedding_funcs={
                Modality.IMAGE: self.get_image_feature,
            },
            positions=positions,
        )

        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # TODO:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", 0),
            (".qkv_proj", ".k_proj", 1),
            (".qkv_proj", ".v_proj", 2),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]

        if self.num_fused_shared_experts > 0:
            assert self.num_fused_shared_experts == 1
            log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")

        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.text_config.moe_num_experts
            + self.num_fused_shared_experts,
        )

        params_dict = dict(self.named_parameters())
        loaded_params = set()

        def match_expert_and_shard_ids(name_path: str, weight_path: str) -> bool:
            name_parts = name_path.split(".")
            weight_parts = weight_path.split(".")
            shard_id_matches = name_parts[4] == weight_parts[2]
            return shard_id_matches

        for name, loaded_weight in weights:
            if "vision_model" in name:
                # 1.It’s not great, but let’s leave it like this for now
                name = name.replace("self_attn", "self_attn.attn")
                # 2.
                name = name.replace("out_proj", "proj")

            # TODO: support vision model
            if self.num_fused_shared_experts > 0 and "share" in name:
                # assert False
                name = name.replace("share_expert", "moe")
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if (
                        expert_id != self.config.text_config.moe_num_experts
                        or not match_expert_and_shard_ids(name, weight_name)
                    ):
                        continue

                    part_name = weight_name.split(".")[-2]
                    fake_weight_name = name.replace(part_name, weight_name[:-1])
                    actual_param_name = name.replace(part_name + ".", param_name)
                    param = params_dict[actual_param_name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if "gate." not in name and "moe" in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                if "moe" not in name:
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                    loaded_params.add(name)
                else:
                    if "gate." in name:
                        name = name.replace(weight_name, param_name)
                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        weight_loader(param, loaded_weight)
                        loaded_params.add(name)
                        continue

                    for mapping in expert_params_mapping:
                        param_name, weight_name, expert_id, shard_id = mapping
                        if expert_id == self.config.text_config.moe_num_experts:
                            continue
                        if not match_expert_and_shard_ids(name, weight_name):
                            continue
                        part_name = weight_name.split(".")[-2]
                        fake_weight_name = name.replace(part_name, weight_name[:-1])
                        actual_param_name = name.replace(part_name + ".", param_name)
                        param = params_dict[actual_param_name]
                        weight_loader = param.weight_loader
                        weight_loader(
                            param,
                            loaded_weight[expert_id],
                            name,
                            shard_id=shard_id,
                            expert_id=expert_id,
                        )
                        loaded_params.add(actual_param_name)
                        # Don't break here, because this 'loaded_weight' includes all the weights for this layer

    @classmethod
    def get_model_config_for_expert_location(cls, config: Step3VLConfig):
        return ModelConfigForExpertLocation(
            num_layers=config.text_config.num_hidden_layers,
            num_logical_experts=config.text_config.moe_num_experts,
            num_groups=None,
        )


EntryClass = Step3VLForConditionalGeneration
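The `EntryClass` above is what lets the 0.4.10 runtime resolve Step3-VL checkpoints. For context, a minimal sketch of exercising it through sglang's OpenAI-compatible server; the checkpoint name `stepfun-ai/step3`, the tensor-parallel size, the port, and the image URL are illustrative assumptions, not part of this diff:

    # Server launched separately beforehand, e.g.:
    #   python -m sglang.launch_server --model-path stepfun-ai/step3 --trust-remote-code --tp 8
    import openai

    client = openai.OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
    resp = client.chat.completions.create(
        model="default",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                    {"type": "text", "text": "Describe this image."},
                ],
            }
        ],
    )
    print(resp.choices[0].message.content)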