sglang 0.4.9.post6__py3-none-any.whl → 0.4.10__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/model_config.py +3 -0
- sglang/srt/configs/step3_vl.py +172 -0
- sglang/srt/conversation.py +23 -0
- sglang/srt/disaggregation/decode.py +2 -8
- sglang/srt/disaggregation/prefill.py +2 -6
- sglang/srt/distributed/parallel_state.py +86 -1
- sglang/srt/entrypoints/engine.py +14 -18
- sglang/srt/entrypoints/http_server.py +10 -2
- sglang/srt/entrypoints/openai/serving_chat.py +2 -21
- sglang/srt/eplb/expert_distribution.py +5 -0
- sglang/srt/eplb/expert_location.py +17 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -0
- sglang/srt/eplb/expert_location_updater.py +2 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/step3_detector.py +436 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/jinja_template_utils.py +4 -1
- sglang/srt/layers/moe/cutlass_moe.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +20 -640
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
- sglang/srt/layers/moe/fused_moe_triton/layer.py +97 -38
- sglang/srt/layers/quantization/fp8.py +0 -18
- sglang/srt/layers/quantization/unquant.py +0 -8
- sglang/srt/layers/quantization/w4afp8.py +1 -0
- sglang/srt/managers/cache_controller.py +143 -45
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/io_struct.py +0 -2
- sglang/srt/managers/scheduler.py +89 -671
- sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
- sglang/srt/managers/template_manager.py +62 -19
- sglang/srt/managers/tokenizer_manager.py +123 -74
- sglang/srt/managers/tp_worker.py +4 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
- sglang/srt/mem_cache/hicache_storage.py +45 -11
- sglang/srt/mem_cache/hiradix_cache.py +15 -4
- sglang/srt/mem_cache/memory_pool_host.py +73 -1
- sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
- sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +177 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
- sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
- sglang/srt/model_executor/model_runner.py +5 -0
- sglang/srt/models/arcee.py +532 -0
- sglang/srt/models/deepseek_v2.py +2 -0
- sglang/srt/models/glm4_moe.py +3 -1
- sglang/srt/models/granitemoe.py +3 -0
- sglang/srt/models/grok.py +3 -0
- sglang/srt/models/hunyuan.py +1 -0
- sglang/srt/models/llama4.py +3 -0
- sglang/srt/models/mixtral.py +3 -0
- sglang/srt/models/olmoe.py +3 -0
- sglang/srt/models/phimoe.py +1 -0
- sglang/srt/models/step3_vl.py +994 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -16
- sglang/srt/multimodal/processors/step3_vl.py +515 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +10 -13
- sglang/srt/speculative/eagle_worker.py +2 -0
- sglang/utils.py +0 -11
- sglang/version.py +1 -1
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/METADATA +3 -4
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/RECORD +69 -56
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/top_level.txt +0 -0
sglang/srt/models/arcee.py
ADDED
@@ -0,0 +1,532 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Inference-only Arcee Foundational Model (AFM) compatible with HuggingFace weights."""
+
+import logging
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+
+from sglang.srt.distributed import (
+    get_pp_group,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    kv_cache_scales_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.utils import add_prefix, make_layers
+
+logger = logging.getLogger(__name__)
+
+
+class ArceeMLP(nn.Module):
+    """
+    MLP block for the Arcee model, using a ReLU-squared activation function.
+    This differs from the Llama SwiGLU activation.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        reduce_results: bool = True,
+    ) -> None:
+        super().__init__()
+        # Arcee uses a single up-projection, not a merged gate/up projection.
+        self.up_proj = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("up_proj", prefix),
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
+            reduce_results=reduce_results,
+        )
+        if hidden_act != "relu2":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Arcee model in SGLang only supports 'relu2'."
+            )
+        # The activation function is relu(x)^2
+        self.act_fn = get_act_fn("relu2")
+
+    def forward(self, x, forward_batch=None):
+        x, _ = self.up_proj(x)
+        x = self.act_fn(x)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class ArceeAttention(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        layer_id: int = 0,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        rope_is_neox_style: bool = True,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        bias: bool = False,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = getattr(config, "head_dim", None)
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.total_num_heads
+        self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
+        self.rotary_dim = int(self.partial_rotary_factor * self.head_dim)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.rotary_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=rope_is_neox_style,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class ArceeDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(
+            config, "original_max_position_embeddings", None
+        ):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
+        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        attention_bias = getattr(config, "attention_bias", False) or getattr(
+            config, "bias", False
+        )
+        self.self_attn = ArceeAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            rope_is_neox_style=rope_is_neox_style,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
+            bias=attention_bias,
+        )
+        self.mlp = ArceeMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class ArceeModel(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: ArceeDecoderLayer(
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix="model.layers",
+        )
+
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
+        self.layers_to_capture = []
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]], PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        aux_hidden_states = []
+        for i in range(self.start_layer, self.end_layer):
+            if i in self.layers_to_capture:
+                aux_hidden_states.append(hidden_states + residual)
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+            )
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
+
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            if not isinstance(self.layers[layer_idx], nn.Identity):
+                layer_self_attn = self.layers[layer_idx].self_attn
+
+            if hasattr(layer_self_attn.attn, "k_scale"):
+                layer_self_attn.attn.k_scale = scaling_factor
+                layer_self_attn.attn.v_scale = scaling_factor
+            else:
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling factor attribute!"
+                )
+
+
+class ArceeForCausalLM(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        # Note: gate_proj is removed compared to Llama
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    # in TP, these weights are partitioned along the column dimension (dim=-1)
+    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        # Note: gate_proj and up_proj are removed as they are not stacked in ArceeMLP
+        ".q_proj": (".qkv_proj", 0),
+        ".k_proj": (".qkv_proj", 1),
+        ".v_proj": (".qkv_proj", 2),
+    }
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.pp_group = get_pp_group()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = self._init_model(config, quant_config, add_prefix("model", prefix))
+        # Arcee does not tie word embeddings
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        # Parameters that are stacked in a single tensor in this model
+        self.stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+        ]
+        self.capture_aux_hidden_states = False
+
+    def _init_model(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        return ArceeModel(config, quant_config=quant_config, prefix=prefix)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> LogitsProcessorOutput:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids,
+                    hidden_states,
+                    self.lm_head,
+                    forward_batch,
+                    aux_hidden_states,
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                continue
+
+            # Handle FP8 kv-scale remapping
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+            is_stacked = False
+            for param_name, weight_name, shard_id in self.stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+
+                name = name.replace(weight_name, param_name)
+                if name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                is_stacked = True
+                break
+
+            if not is_stacked:
+                if name in params_dict:
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in model.")
+
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+
+
+EntryClass = [ArceeForCausalLM]
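The core difference between `ArceeMLP` above and the usual Llama-style MLP is the activation: a single up-projection followed by relu(x)^2 ("relu2") rather than a SwiGLU gate, as the module's docstring and comments note. The snippet below is a minimal plain-PyTorch restatement of that computation, not sglang code; the class name `ReluSquaredMLP` is illustrative, and the tensor-parallel linears, quantization hooks, and `forward_batch` plumbing of the real module are omitted.

import torch
from torch import nn


class ReluSquaredMLP(nn.Module):
    """Single up-projection, relu(x)**2 activation, then a down-projection."""

    def __init__(self, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # "relu2": square the ReLU output; contrast with Llama's SwiGLU,
        # which multiplies silu(gate_proj(x)) by a separate up_proj(x).
        return self.down_proj(torch.relu(self.up_proj(x)) ** 2)


# Quick shape check.
mlp = ReluSquaredMLP(hidden_size=64, intermediate_size=256)
print(mlp(torch.randn(2, 64)).shape)  # torch.Size([2, 64])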
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -325,6 +325,7 @@ class DeepseekV2MoE(nn.Module):
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
+            num_fused_shared_experts=self.num_fused_shared_experts,
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
@@ -2112,6 +2113,7 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
+            self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
                 f"{disable_reason} Shared experts fusion optimization is disabled.",
sglang/srt/models/glm4_moe.py
CHANGED
@@ -434,6 +434,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
+            num_fused_shared_experts=self.num_fused_shared_experts,
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
@@ -740,10 +741,11 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
             global_server_args_dict["enable_deepep_moe"]
             or global_server_args_dict["enable_ep_moe"]
         ):
-            disable_reason = "Deepseek GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
+            disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
 
         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
+            self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
                 f"{disable_reason} Shared experts fusion optimization is disabled.",
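Both the DeepSeek-V2 and GLM-4.5 hunks above enforce the same invariant: whenever shared-experts fusion is disabled at runtime, `num_fused_shared_experts` is reset to 0, because that count is added into both the expert total and the top-k passed to the MoE layer. A small self-contained sketch of that bookkeeping (bare variables stand in for the config and server-args fields; the names are illustrative only):

# Stand-ins for config.n_routed_experts / config.num_experts_per_tok.
n_routed_experts, num_experts_per_tok = 64, 8

num_fused_shared_experts = 1  # assume fusion is on by default
disable_reason = "deepep_moe or ep_moe mode"  # any non-None reason disables it

if disable_reason is not None:
    disable_shared_experts_fusion = True
    num_fused_shared_experts = 0  # keep the sizing below consistent

num_experts = n_routed_experts + num_fused_shared_experts
top_k = num_experts_per_tok + num_fused_shared_experts
print(num_experts, top_k)  # 64 8 when fusion is disabled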
sglang/srt/models/granitemoe.py
CHANGED
@@ -43,6 +43,7 @@ class GraniteMoeMoE(nn.Module):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
@@ -71,6 +72,7 @@ class GraniteMoeMoE(nn.Module):
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
+            layer_id=layer_id,
             params_dtype=params_dtype,
             reduce_results=True,
             quant_config=quant_config,
@@ -203,6 +205,7 @@ class GraniteMoeDecoderLayer(nn.Module):
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
+            layer_id=layer_id,
             quant_config=quant_config,
             prefix=f"{prefix}.block_sparse_moe",
         )
sglang/srt/models/grok.py
CHANGED
@@ -78,6 +78,7 @@ class Grok1MoE(nn.Module):
     def __init__(
         self,
        config: PretrainedConfig,
+        layer_id: int,
         num_experts: int,
         top_k: int,
         hidden_size: int,
@@ -128,6 +129,7 @@ class Grok1MoE(nn.Module):
         self.experts = MoEImpl(
             num_experts=num_experts,
             top_k=top_k,
+            layer_id=layer_id,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
@@ -331,6 +333,7 @@ class Grok1DecoderLayer(nn.Module):
         )
         self.block_sparse_moe = Grok1MoE(
             config=config,
+            layer_id=layer_id,
             num_experts=config.num_local_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
sglang/srt/models/hunyuan.py
CHANGED
sglang/srt/models/llama4.py
CHANGED
@@ -87,6 +87,7 @@ class Llama4MoE(nn.Module):
     def __init__(
         self,
         config: Llama4TextConfig,
+        layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -114,6 +115,7 @@ class Llama4MoE(nn.Module):
             num_experts=config.num_local_experts,
             hidden_size=config.hidden_size,
             intermediate_size=intermediate_size_moe,
+            layer_id=layer_id,
             reduce_results=False,
             quant_config=quant_config,
             apply_router_weight_on_input=True,
@@ -373,6 +375,7 @@ class Llama4DecoderLayer(nn.Module):
         if is_moe_layer:
             self.feed_forward = Llama4MoE(
                 config=config,
+                layer_id=layer_id,
                 quant_config=quant_config,
                 prefix=add_prefix("feed_forward", prefix),
             )
sglang/srt/models/mixtral.py
CHANGED
@@ -69,6 +69,7 @@ class MixtralMoE(nn.Module):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
@@ -97,6 +98,7 @@ class MixtralMoE(nn.Module):
         self.experts = MoEImpl(
             num_experts=num_experts,
             top_k=top_k,
+            layer_id=layer_id,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
@@ -226,6 +228,7 @@ class MixtralDecoderLayer(nn.Module):
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
+            layer_id=layer_id,
             quant_config=quant_config,
             prefix=add_prefix("block_sparse_moe", prefix),
         )
sglang/srt/models/olmoe.py
CHANGED
@@ -63,6 +63,7 @@ class OlmoeMoE(nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
+        layer_id: int = 0,
         prefix: str = "",
     ):
         super().__init__()
@@ -89,6 +90,7 @@ class OlmoeMoE(nn.Module):
             reduce_results=True,
             quant_config=quant_config,
             tp_size=tp_size,
+            layer_id=layer_id,
             prefix=add_prefix("experts", prefix),
         )
 
@@ -224,6 +226,7 @@ class OlmoeDecoderLayer(nn.Module):
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
+            layer_id=layer_id,
             quant_config=quant_config,
             prefix=add_prefix("mlp", prefix),
         )
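The remaining model diffs (GraniteMoE, Grok-1, Llama 4, Mixtral, OLMoE) apply one mechanical change: the decoder layer's index is now threaded through the MoE wrapper into the fused expert implementation as `layer_id`. A minimal sketch of that plumbing pattern, with hypothetical class names standing in for sglang's modules:

import torch
from torch import nn


class FusedExperts(nn.Module):
    """Stand-in for the fused MoE expert layer, now constructed with layer_id."""

    def __init__(self, hidden_size: int, layer_id: int) -> None:
        super().__init__()
        self.layer_id = layer_id  # identifies which decoder layer owns these experts
        self.proj = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class MoEBlock(nn.Module):
    """MoE wrapper: before this release it neither accepted nor forwarded layer_id."""

    def __init__(self, hidden_size: int, layer_id: int) -> None:
        super().__init__()
        self.experts = FusedExperts(hidden_size, layer_id=layer_id)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.experts(x)


class DecoderLayer(nn.Module):
    def __init__(self, hidden_size: int, layer_id: int) -> None:
        super().__init__()
        # layer_id originates here and is passed all the way down.
        self.block_sparse_moe = MoEBlock(hidden_size, layer_id=layer_id)


layer = DecoderLayer(hidden_size=32, layer_id=3)
print(layer.block_sparse_moe.experts.layer_id)  # 3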