sglang 0.4.9.post6__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +20 -0
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/model_config.py +4 -0
- sglang/srt/configs/step3_vl.py +172 -0
- sglang/srt/conversation.py +23 -0
- sglang/srt/disaggregation/decode.py +2 -8
- sglang/srt/disaggregation/launch_lb.py +5 -20
- sglang/srt/disaggregation/mooncake/conn.py +33 -15
- sglang/srt/disaggregation/prefill.py +2 -6
- sglang/srt/distributed/parallel_state.py +86 -1
- sglang/srt/entrypoints/engine.py +14 -18
- sglang/srt/entrypoints/http_server.py +10 -2
- sglang/srt/entrypoints/openai/serving_chat.py +2 -21
- sglang/srt/eplb/expert_distribution.py +5 -0
- sglang/srt/eplb/expert_location.py +17 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -0
- sglang/srt/eplb/expert_location_updater.py +2 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/step3_detector.py +436 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/jinja_template_utils.py +4 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
- sglang/srt/layers/attention/utils.py +6 -1
- sglang/srt/layers/moe/cutlass_moe.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +39 -674
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
- sglang/srt/layers/moe/fused_moe_triton/layer.py +152 -39
- sglang/srt/layers/quantization/fp8.py +52 -18
- sglang/srt/layers/quantization/unquant.py +0 -8
- sglang/srt/layers/quantization/w4afp8.py +1 -0
- sglang/srt/layers/quantization/w8a8_int8.py +4 -1
- sglang/srt/managers/cache_controller.py +165 -67
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/io_struct.py +0 -2
- sglang/srt/managers/scheduler.py +90 -671
- sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
- sglang/srt/managers/template_manager.py +62 -19
- sglang/srt/managers/tokenizer_manager.py +123 -74
- sglang/srt/managers/tp_worker.py +4 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
- sglang/srt/mem_cache/hicache_storage.py +60 -17
- sglang/srt/mem_cache/hiradix_cache.py +36 -8
- sglang/srt/mem_cache/memory_pool.py +15 -118
- sglang/srt/mem_cache/memory_pool_host.py +418 -29
- sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
- sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
- sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
- sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
- sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +183 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
- sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
- sglang/srt/model_executor/cuda_graph_runner.py +25 -1
- sglang/srt/model_executor/model_runner.py +13 -1
- sglang/srt/model_loader/weight_utils.py +2 -0
- sglang/srt/models/arcee.py +532 -0
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/glm4_moe.py +6 -4
- sglang/srt/models/granitemoe.py +3 -0
- sglang/srt/models/grok.py +3 -0
- sglang/srt/models/hunyuan.py +1 -0
- sglang/srt/models/llama4.py +3 -0
- sglang/srt/models/mixtral.py +3 -0
- sglang/srt/models/olmoe.py +3 -0
- sglang/srt/models/phimoe.py +1 -0
- sglang/srt/models/step3_vl.py +991 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -16
- sglang/srt/multimodal/processors/step3_vl.py +515 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +49 -18
- sglang/srt/speculative/eagle_worker.py +2 -0
- sglang/srt/utils.py +1 -0
- sglang/test/attention/test_trtllm_mla_backend.py +945 -0
- sglang/utils.py +0 -11
- sglang/version.py +1 -1
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +3 -4
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +83 -65
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/arcee.py
ADDED

```diff
@@ -0,0 +1,532 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Inference-only Arcee Foundational Model (AFM) compatible with HuggingFace weights."""
+
+import logging
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+
+from sglang.srt.distributed import (
+    get_pp_group,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    kv_cache_scales_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.utils import add_prefix, make_layers
+
+logger = logging.getLogger(__name__)
+
+
+class ArceeMLP(nn.Module):
+    """
+    MLP block for the Arcee model, using a ReLU-squared activation function.
+    This differs from the Llama SwiGLU activation.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        reduce_results: bool = True,
+    ) -> None:
+        super().__init__()
+        # Arcee uses a single up-projection, not a merged gate/up projection.
+        self.up_proj = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("up_proj", prefix),
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
+            reduce_results=reduce_results,
+        )
+        if hidden_act != "relu2":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Arcee model in SGLang only supports 'relu2'."
+            )
+        # The activation function is relu(x)^2
+        self.act_fn = get_act_fn("relu2")
+
+    def forward(self, x, forward_batch=None):
+        x, _ = self.up_proj(x)
+        x = self.act_fn(x)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class ArceeAttention(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        layer_id: int = 0,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        rope_is_neox_style: bool = True,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        bias: bool = False,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = getattr(config, "head_dim", None)
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.total_num_heads
+        self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
+        self.rotary_dim = int(self.partial_rotary_factor * self.head_dim)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.rotary_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=rope_is_neox_style,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class ArceeDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(
+            config, "original_max_position_embeddings", None
+        ):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
+        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        attention_bias = getattr(config, "attention_bias", False) or getattr(
+            config, "bias", False
+        )
+        self.self_attn = ArceeAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            rope_is_neox_style=rope_is_neox_style,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
+            bias=attention_bias,
+        )
+        self.mlp = ArceeMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class ArceeModel(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: ArceeDecoderLayer(
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix="model.layers",
+        )
+
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
+        self.layers_to_capture = []
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]], PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        aux_hidden_states = []
+        for i in range(self.start_layer, self.end_layer):
+            if i in self.layers_to_capture:
+                aux_hidden_states.append(hidden_states + residual)
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+            )
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
+
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            if not isinstance(self.layers[layer_idx], nn.Identity):
+                layer_self_attn = self.layers[layer_idx].self_attn
+
+            if hasattr(layer_self_attn.attn, "k_scale"):
+                layer_self_attn.attn.k_scale = scaling_factor
+                layer_self_attn.attn.v_scale = scaling_factor
+            else:
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling factor attribute!"
+                )
+
+
+class ArceeForCausalLM(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        # Note: gate_proj is removed compared to Llama
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    # in TP, these weights are partitioned along the column dimension (dim=-1)
+    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        # Note: gate_proj and up_proj are removed as they are not stacked in ArceeMLP
+        ".q_proj": (".qkv_proj", 0),
+        ".k_proj": (".qkv_proj", 1),
+        ".v_proj": (".qkv_proj", 2),
+    }
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.pp_group = get_pp_group()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = self._init_model(config, quant_config, add_prefix("model", prefix))
+        # Arcee does not tie word embeddings
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        # Parameters that are stacked in a single tensor in this model
+        self.stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+        ]
+        self.capture_aux_hidden_states = False
+
+    def _init_model(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        return ArceeModel(config, quant_config=quant_config, prefix=prefix)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> LogitsProcessorOutput:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids,
+                    hidden_states,
+                    self.lm_head,
+                    forward_batch,
+                    aux_hidden_states,
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                continue
+
+            # Handle FP8 kv-scale remapping
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+            is_stacked = False
+            for param_name, weight_name, shard_id in self.stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+
+                name = name.replace(weight_name, param_name)
+                if name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                is_stacked = True
+                break
+
+            if not is_stacked:
+                if name in params_dict:
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in model.")
+
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+
+
+EntryClass = [ArceeForCausalLM]
```
sglang/srt/models/deepseek_v2.py
CHANGED
```diff
@@ -59,7 +59,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import (
     DeepEPMoE,
     get_moe_impl_class,
-
+    should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import TopK
@@ -252,8 +252,7 @@ class MoEGate(nn.Module):
         # NOTE: For some unknown reason, router_gemm seems degrade accept length.
         if (
             _is_cuda
-            and
-            and hidden_states.shape[0] < 4
+            and hidden_states.shape[0] <= 16
             and hidden_states.shape[1] == 7168
             and self.weight.shape[0] == 256
             and _device_sm >= 90
@@ -317,7 +316,7 @@ class DeepseekV2MoE(nn.Module):
                 correction_bias=self.gate.e_score_correction_bias,
                 routed_scaling_factor=self.routed_scaling_factor,
             )
-            if not
+            if not should_use_flashinfer_trtllm_moe()
             else None
         )
 
@@ -325,6 +324,7 @@
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
+            num_fused_shared_experts=self.num_fused_shared_experts,
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
@@ -351,11 +351,10 @@
                     renormalize=config.norm_topk_prob,
                     use_grouped_topk=True,
                     num_expert_group=config.n_group,
-                    num_fused_shared_experts=self.num_fused_shared_experts,
                     topk_group=config.topk_group,
                     correction_bias=self.gate.e_score_correction_bias,
                 )
-                if
+                if should_use_flashinfer_trtllm_moe()
                 else {}
             ),
         )
@@ -1258,6 +1257,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             self.current_attention_backend == "fa3"
             or self.current_attention_backend == "flashinfer"
             or self.current_attention_backend == "cutlass_mla"
+            or self.current_attention_backend == "trtllm_mla"
         ):
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
@@ -2112,6 +2112,7 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
+            self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
                 f"{disable_reason} Shared experts fusion optimization is disabled.",
```
sglang/srt/models/glm4_moe.py
CHANGED
```diff
@@ -52,7 +52,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import (
     DeepEPMoE,
     get_moe_impl_class,
-
+    should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -426,7 +426,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
                 correction_bias=self.gate.e_score_correction_bias,
                 routed_scaling_factor=self.routed_scaling_factor,
             )
-            if not
+            if not should_use_flashinfer_trtllm_moe()
             else None
         )
 
@@ -434,6 +434,7 @@
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
+            num_fused_shared_experts=self.num_fused_shared_experts,
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
@@ -464,7 +465,7 @@
                     topk_group=config.topk_group,
                     correction_bias=self.gate.e_score_correction_bias,
                 )
-                if
+                if should_use_flashinfer_trtllm_moe()
                 else {}
             ),
         )
@@ -740,10 +741,11 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
             global_server_args_dict["enable_deepep_moe"]
             or global_server_args_dict["enable_ep_moe"]
         ):
-            disable_reason = "Deepseek GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
+            disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
 
         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
+            self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
                 f"{disable_reason} Shared experts fusion optimization is disabled.",
```
sglang/srt/models/granitemoe.py
CHANGED
```diff
@@ -43,6 +43,7 @@ class GraniteMoeMoE(nn.Module):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
@@ -71,6 +72,7 @@
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
+            layer_id=layer_id,
             params_dtype=params_dtype,
             reduce_results=True,
             quant_config=quant_config,
@@ -203,6 +205,7 @@ class GraniteMoeDecoderLayer(nn.Module):
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
+            layer_id=layer_id,
             quant_config=quant_config,
             prefix=f"{prefix}.block_sparse_moe",
         )
```
sglang/srt/models/grok.py
CHANGED
```diff
@@ -78,6 +78,7 @@ class Grok1MoE(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        layer_id: int,
         num_experts: int,
         top_k: int,
         hidden_size: int,
@@ -128,6 +129,7 @@
         self.experts = MoEImpl(
             num_experts=num_experts,
             top_k=top_k,
+            layer_id=layer_id,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
@@ -331,6 +333,7 @@ class Grok1DecoderLayer(nn.Module):
         )
         self.block_sparse_moe = Grok1MoE(
             config=config,
+            layer_id=layer_id,
             num_experts=config.num_local_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
```
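
The `granitemoe.py` and `grok.py` hunks are mechanical: each decoder layer now forwards its own `layer_id` down to the MoE block, which in turn passes it to the underlying MoE implementation. A toy sketch of that plumbing is shown below; the class names and the per-layer bookkeeping comment are illustrative assumptions, not SGLang code.

```python
import torch
from torch import nn

class ToyMoE(nn.Module):
    """Stand-in for the MoE implementation that now receives its layer index."""
    def __init__(self, num_experts: int, top_k: int, hidden_size: int, layer_id: int):
        super().__init__()
        self.layer_id = layer_id  # newly threaded through (e.g. for per-layer bookkeeping)
        self.num_experts, self.top_k = num_experts, top_k
        self.proj = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

class ToyDecoderLayer(nn.Module):
    """The decoder layer forwards its own index, as in the grok/granitemoe hunks."""
    def __init__(self, hidden_size: int, layer_id: int):
        super().__init__()
        self.block_sparse_moe = ToyMoE(
            num_experts=8, top_k=2, hidden_size=hidden_size, layer_id=layer_id
        )

layer = ToyDecoderLayer(hidden_size=64, layer_id=3)
print(layer.block_sparse_moe.layer_id)  # 3
```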
|