sglang 0.4.9.post5__py3-none-any.whl → 0.4.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/configs/step3_vl.py +172 -0
- sglang/srt/conversation.py +23 -0
- sglang/srt/disaggregation/decode.py +2 -8
- sglang/srt/disaggregation/prefill.py +2 -6
- sglang/srt/distributed/parallel_state.py +86 -1
- sglang/srt/entrypoints/engine.py +14 -18
- sglang/srt/entrypoints/http_server.py +23 -3
- sglang/srt/entrypoints/openai/protocol.py +3 -1
- sglang/srt/entrypoints/openai/serving_base.py +5 -2
- sglang/srt/entrypoints/openai/serving_chat.py +2 -21
- sglang/srt/eplb/expert_distribution.py +5 -0
- sglang/srt/eplb/expert_location.py +17 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -0
- sglang/srt/eplb/expert_location_updater.py +2 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/step3_detector.py +436 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/jinja_template_utils.py +4 -1
- sglang/srt/layers/moe/cutlass_moe.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +98 -603
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
- sglang/srt/layers/moe/fused_moe_triton/layer.py +97 -38
- sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
- sglang/srt/layers/moe/topk.py +6 -2
- sglang/srt/layers/quantization/fp8.py +0 -18
- sglang/srt/layers/quantization/modelopt_quant.py +2 -0
- sglang/srt/layers/quantization/unquant.py +0 -8
- sglang/srt/layers/quantization/w4afp8.py +1 -0
- sglang/srt/managers/cache_controller.py +143 -45
- sglang/srt/managers/data_parallel_controller.py +6 -0
- sglang/srt/managers/io_struct.py +12 -2
- sglang/srt/managers/scheduler.py +116 -669
- sglang/srt/managers/scheduler_input_blocker.py +106 -0
- sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
- sglang/srt/managers/template_manager.py +62 -19
- sglang/srt/managers/tokenizer_manager.py +166 -83
- sglang/srt/managers/tp_worker.py +9 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
- sglang/srt/mem_cache/hicache_storage.py +45 -11
- sglang/srt/mem_cache/hiradix_cache.py +15 -4
- sglang/srt/mem_cache/memory_pool_host.py +73 -1
- sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
- sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +177 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
- sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
- sglang/srt/model_executor/model_runner.py +20 -13
- sglang/srt/models/arcee.py +532 -0
- sglang/srt/models/deepseek_v2.py +15 -56
- sglang/srt/models/glm4_moe.py +3 -1
- sglang/srt/models/granitemoe.py +3 -0
- sglang/srt/models/grok.py +3 -0
- sglang/srt/models/hunyuan.py +1 -0
- sglang/srt/models/llama4.py +3 -0
- sglang/srt/models/mixtral.py +3 -0
- sglang/srt/models/olmoe.py +3 -0
- sglang/srt/models/phimoe.py +1 -0
- sglang/srt/models/qwen3_moe.py +12 -69
- sglang/srt/models/step3_vl.py +994 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -16
- sglang/srt/multimodal/processors/step3_vl.py +515 -0
- sglang/srt/poll_based_barrier.py +31 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +18 -13
- sglang/srt/speculative/eagle_worker.py +2 -0
- sglang/srt/two_batch_overlap.py +8 -3
- sglang/test/test_utils.py +53 -0
- sglang/utils.py +0 -11
- sglang/version.py +1 -1
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/METADATA +4 -4
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/RECORD +84 -64
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/top_level.txt +0 -0
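One structural change visible in the manifest above is the scheduler split: `sglang/srt/managers/scheduler.py` shrinks by roughly 670 lines while three new `scheduler_*_mixin.py` modules appear, which suggests the `Scheduler` class now composes that logic from mixins. A minimal sketch of that composition pattern (class and method names below are hypothetical, not the actual SGLang API):

```python
# Illustrative sketch of a mixin-based split; names are hypothetical and do not
# mirror the real SGLang Scheduler API.
class MetricsMixin:
    def init_metrics(self) -> None:
        self.num_steps = 0  # shared state lives on the combined class

    def record_step(self) -> None:
        self.num_steps += 1


class ProfilerMixin:
    def start_profile(self) -> None:
        self.profiling = True  # e.g. arm a profiler session here


class UpdateWeightsMixin:
    def update_weights(self, named_tensors: dict) -> int:
        return len(named_tensors)  # e.g. swap tensors into the running model


class Scheduler(MetricsMixin, ProfilerMixin, UpdateWeightsMixin):
    """The main class stays small; each feature group lives in its own file."""

    def __init__(self) -> None:
        self.init_metrics()


scheduler = Scheduler()
scheduler.record_step()
scheduler.start_profile()
```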
sglang/srt/models/arcee.py
ADDED
```diff
@@ -0,0 +1,532 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Inference-only Arcee Foundational Model (AFM) compatible with HuggingFace weights."""
+
+import logging
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+
+from sglang.srt.distributed import (
+    get_pp_group,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    kv_cache_scales_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.utils import add_prefix, make_layers
+
+logger = logging.getLogger(__name__)
+
+
+class ArceeMLP(nn.Module):
+    """
+    MLP block for the Arcee model, using a ReLU-squared activation function.
+    This differs from the Llama SwiGLU activation.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        reduce_results: bool = True,
+    ) -> None:
+        super().__init__()
+        # Arcee uses a single up-projection, not a merged gate/up projection.
+        self.up_proj = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("up_proj", prefix),
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
+            reduce_results=reduce_results,
+        )
+        if hidden_act != "relu2":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Arcee model in SGLang only supports 'relu2'."
+            )
+        # The activation function is relu(x)^2
+        self.act_fn = get_act_fn("relu2")
+
+    def forward(self, x, forward_batch=None):
+        x, _ = self.up_proj(x)
+        x = self.act_fn(x)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class ArceeAttention(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        layer_id: int = 0,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        rope_is_neox_style: bool = True,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        bias: bool = False,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = getattr(config, "head_dim", None)
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.total_num_heads
+        self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
+        self.rotary_dim = int(self.partial_rotary_factor * self.head_dim)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.rotary_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=rope_is_neox_style,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class ArceeDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(
+            config, "original_max_position_embeddings", None
+        ):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
+        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        attention_bias = getattr(config, "attention_bias", False) or getattr(
+            config, "bias", False
+        )
+        self.self_attn = ArceeAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            rope_is_neox_style=rope_is_neox_style,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
+            bias=attention_bias,
+        )
+        self.mlp = ArceeMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class ArceeModel(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: ArceeDecoderLayer(
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix="model.layers",
+        )
+
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
+        self.layers_to_capture = []
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]], PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        aux_hidden_states = []
+        for i in range(self.start_layer, self.end_layer):
+            if i in self.layers_to_capture:
+                aux_hidden_states.append(hidden_states + residual)
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+            )
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
+
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            if not isinstance(self.layers[layer_idx], nn.Identity):
+                layer_self_attn = self.layers[layer_idx].self_attn
+
+            if hasattr(layer_self_attn.attn, "k_scale"):
+                layer_self_attn.attn.k_scale = scaling_factor
+                layer_self_attn.attn.v_scale = scaling_factor
+            else:
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling factor attribute!"
+                )
+
+
+class ArceeForCausalLM(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        # Note: gate_proj is removed compared to Llama
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    # in TP, these weights are partitioned along the column dimension (dim=-1)
+    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        # Note: gate_proj and up_proj are removed as they are not stacked in ArceeMLP
+        ".q_proj": (".qkv_proj", 0),
+        ".k_proj": (".qkv_proj", 1),
+        ".v_proj": (".qkv_proj", 2),
+    }
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.pp_group = get_pp_group()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = self._init_model(config, quant_config, add_prefix("model", prefix))
+        # Arcee does not tie word embeddings
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        # Parameters that are stacked in a single tensor in this model
+        self.stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+        ]
+        self.capture_aux_hidden_states = False
+
+    def _init_model(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        return ArceeModel(config, quant_config=quant_config, prefix=prefix)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> LogitsProcessorOutput:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids,
+                    hidden_states,
+                    self.lm_head,
+                    forward_batch,
+                    aux_hidden_states,
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                continue
+
+            # Handle FP8 kv-scale remapping
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+            is_stacked = False
+            for param_name, weight_name, shard_id in self.stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+
+                name = name.replace(weight_name, param_name)
+                if name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                is_stacked = True
+                break
+
+            if not is_stacked:
+                if name in params_dict:
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in model.")
+
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+
+
+EntryClass = [ArceeForCausalLM]
```
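The ArceeMLP block above uses a single up-projection followed by a "relu2" (ReLU-squared) activation instead of Llama's gated SwiGLU. A standalone comparison of the two forward paths in plain PyTorch (toy sizes and dense weights, not the tensor-parallel SGLang layers):

```python
import torch
import torch.nn.functional as F

hidden, inter = 1024, 4096            # toy sizes, not taken from any real config
x = torch.randn(4, hidden)            # (num_tokens, hidden_size)
w_up = torch.randn(inter, hidden)
w_gate = torch.randn(inter, hidden)   # only used by the SwiGLU variant
w_down = torch.randn(hidden, inter)

# Arcee ("relu2"): act_fn(up_proj(x)) with act_fn(h) = relu(h) ** 2
arcee_out = (F.relu(x @ w_up.t()) ** 2) @ w_down.t()

# Llama (SwiGLU): silu(gate_proj(x)) * up_proj(x), then down_proj
llama_out = (F.silu(x @ w_gate.t()) * (x @ w_up.t())) @ w_down.t()

print(arcee_out.shape, llama_out.shape)  # both (4, 1024)
```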
sglang/srt/models/deepseek_v2.py
CHANGED
```diff
@@ -325,6 +325,7 @@ class DeepseekV2MoE(nn.Module):
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
+            num_fused_shared_experts=self.num_fused_shared_experts,
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
@@ -594,41 +595,13 @@ class DeepseekV2MoE(nn.Module):
             topk_weights = torch.empty(
                 (0, self.top_k), dtype=torch.float32, device=hidden_states.device
             )
-
-        # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
-        (
-            hidden_states,
-            topk_idx,
-            topk_weights,
-            reorder_topk_ids,
-            num_recv_tokens_per_expert,
-            seg_indptr,
-            masked_m,
-            expected_m,
-        ) = self.deepep_dispatcher.dispatch(
-            hidden_states=hidden_states,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
-            forward_batch=forward_batch,
-        )
+
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
-            reorder_topk_ids=reorder_topk_ids,
-            seg_indptr=seg_indptr,
-            masked_m=masked_m,
-            expected_m=expected_m,
-            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
            forward_batch=forward_batch,
         )
-        if self.ep_size > 1:
-            final_hidden_states = self.deepep_dispatcher.combine(
-                hidden_states=final_hidden_states,
-                topk_idx=topk_idx,
-                topk_weights=topk_weights,
-                forward_batch=forward_batch,
-            )
 
         if shared_output is not None:
             x = shared_output
@@ -689,8 +662,7 @@ class DeepseekV2MoE(nn.Module):
 
     def op_dispatch_a(self, state):
         if self.ep_size > 1:
-
-            self.deepep_dispatcher.dispatch_a(
+            self.experts.deepep_dispatcher.dispatch_a(
                 hidden_states=state.hidden_states_mlp_input,
                 topk_idx=state.pop("topk_idx_local"),
                 topk_weights=state.pop("topk_weights_local"),
@@ -703,46 +675,32 @@ class DeepseekV2MoE(nn.Module):
             with get_global_expert_distribution_recorder().with_current_layer(
                 self.layer_id
             ):
-                (
-                    state.hidden_states_experts_input,
-                    state.topk_idx_dispatched,
-                    state.topk_weights_dispatched,
-                    state.reorder_topk_ids,
-                    state.num_recv_tokens_per_expert,
-                    state.seg_indptr,
-                    state.masked_m,
-                    state.expected_m,
-                ) = self.deepep_dispatcher.dispatch_b(
+                state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
                     tbo_subbatch_index=state.get("tbo_subbatch_index"),
                 )
 
     def op_experts(self, state):
-        state.hidden_states_experts_output = self.experts(
-            hidden_states=state.pop("hidden_states_experts_input"),
-            topk_idx=state.topk_idx_dispatched,
-            topk_weights=state.topk_weights_dispatched,
-            reorder_topk_ids=state.pop("reorder_topk_ids"),
-            seg_indptr=state.pop("seg_indptr"),
-            masked_m=state.pop("masked_m"),
-            expected_m=state.pop("expected_m"),
-            num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
-            forward_batch=state.forward_batch,
+        state.hidden_states_experts_output = self.experts.moe_impl(
+            dispatch_output=state.dispatch_output,
         )
 
     def op_combine_a(self, state):
         if self.ep_size > 1:
-            self.deepep_dispatcher.combine_a(
+            self.experts.deepep_dispatcher.combine_a(
                 hidden_states=state.pop("hidden_states_experts_output"),
-                topk_idx=state.pop("topk_idx_dispatched"),
-                topk_weights=state.pop("topk_weights_dispatched"),
+                topk_idx=state.dispatch_output.topk_idx,
+                topk_weights=state.dispatch_output.topk_weights,
                 forward_batch=state.forward_batch,
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
+            state.pop("dispatch_output")
 
     def op_combine_b(self, state):
         if self.ep_size > 1:
-            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
-                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            state.hidden_states_after_combine = (
+                self.experts.deepep_dispatcher.combine_b(
+                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
+                )
             )
 
     def op_output(self, state):
@@ -2155,6 +2113,7 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
+            self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
                 f"{disable_reason} Shared experts fusion optimization is disabled.",
```
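The deepseek_v2.py hunks above replace the eight loose tensors that used to flow between dispatch, the expert GEMMs, and combine with a single `dispatch_output` object, and move ownership of the DeepEP dispatcher from `DeepseekV2MoE` onto the experts module itself (`self.experts.deepep_dispatcher`, `self.experts.moe_impl`). A schematic of the new call flow with toy stand-in classes (not the real SGLang types):

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class DispatchOutput:
    # Toy stand-in for the bundle returned by dispatch_b in the new layout.
    hidden_states: Any
    topk_idx: Any
    topk_weights: Any


class ToyDispatcher:
    def dispatch_b(self, **kwargs) -> DispatchOutput:
        # Real code would exchange tokens across EP ranks here.
        return DispatchOutput(hidden_states="h", topk_idx="idx", topk_weights="w")


class ToyExperts:
    """Toy experts module that owns its dispatcher, as the refactor implies."""

    def __init__(self, dispatcher: ToyDispatcher):
        self.deepep_dispatcher = dispatcher

    def moe_impl(self, dispatch_output: DispatchOutput):
        # Grouped expert compute over the already-dispatched tokens (elided).
        return dispatch_output.hidden_states


def op_experts(state: dict, experts: ToyExperts):
    # Mirrors the refactored op_experts: one bundled argument instead of eight.
    state["hidden_states_experts_output"] = experts.moe_impl(
        dispatch_output=state["dispatch_output"]
    )


experts = ToyExperts(ToyDispatcher())
state = {"dispatch_output": experts.deepep_dispatcher.dispatch_b()}
op_experts(state, experts)
```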
sglang/srt/models/glm4_moe.py
CHANGED
```diff
@@ -434,6 +434,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
+            num_fused_shared_experts=self.num_fused_shared_experts,
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
@@ -740,10 +741,11 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
             global_server_args_dict["enable_deepep_moe"]
             or global_server_args_dict["enable_ep_moe"]
         ):
-            disable_reason = "Deepseek GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
+            disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
 
         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
+            self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
                 f"{disable_reason} Shared experts fusion optimization is disabled.",
```
sglang/srt/models/granitemoe.py
CHANGED
```diff
@@ -43,6 +43,7 @@ class GraniteMoeMoE(nn.Module):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
@@ -71,6 +72,7 @@ class GraniteMoeMoE(nn.Module):
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
+            layer_id=layer_id,
             params_dtype=params_dtype,
             reduce_results=True,
             quant_config=quant_config,
@@ -203,6 +205,7 @@ class GraniteMoeDecoderLayer(nn.Module):
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
+            layer_id=layer_id,
             quant_config=quant_config,
             prefix=f"{prefix}.block_sparse_moe",
         )
```