sglang 0.3.6.post2__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +55 -2
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +4 -3
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/launch_server.py +3 -2
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +6 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/flashinfer_backend.py +13 -15
- sglang/srt/layers/attention/torch_native_backend.py +285 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +34 -0
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -11
- sglang/srt/managers/detokenizer_manager.py +7 -4
- sglang/srt/managers/image_processor.py +1 -1
- sglang/srt/managers/io_struct.py +48 -12
- sglang/srt/managers/schedule_batch.py +42 -36
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +111 -46
- sglang/srt/managers/session_controller.py +0 -3
- sglang/srt/managers/tokenizer_manager.py +169 -100
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +32 -5
- sglang/srt/model_executor/cuda_graph_runner.py +16 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +136 -150
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +2 -3
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +3 -11
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +14 -51
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +97 -27
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +10 -12
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +12 -5
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +391 -0
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -8
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -11
- sglang/srt/models/qwen2_vl.py +12 -9
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -12
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +10 -6
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/server.py +303 -204
- sglang/srt/server_args.py +65 -31
- sglang/srt/utils.py +253 -48
- sglang/test/test_utils.py +27 -7
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/METADATA +2 -1
- sglang-0.4.0.dist-info/RECORD +184 -0
- sglang/srt/layers/fused_moe_grok/__init__.py +0 -1
- sglang/srt/layers/fused_moe_grok/fused_moe.py +0 -692
- sglang/srt/layers/fused_moe_grok/layer.py +0 -630
- sglang-0.3.6.post2.dist-info/RECORD +0 -164
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
sglang/srt/models/olmo2.py
ADDED
@@ -0,0 +1,391 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/olmo2.py
+"""Inference-only OLMo2 model compatible with HuggingFace weights."""
+from functools import partial
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    split_tensor_along_last_dim,
+    tensor_model_parallel_all_gather,
+)
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import make_layers
+
+
+class Olmo2Attention(nn.Module):
+    """
+    This is the attention block where the output is computed as
+    ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
+    (plus another skip connection).
+    """
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = config.num_attention_heads
+
+        assert self.hidden_size % self.total_num_heads == 0
+        assert self.total_num_heads % tp_size == 0
+
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = self.config.num_key_value_heads
+
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+
+        self.head_dim = self.hidden_size // self.total_num_heads
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
+
+        # Attention input projection. Projects x -> (q, k, v)
+        self.qkv_proj = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            bias=config.attention_bias,
+        )
+        self.tp_rank = get_tensor_model_parallel_rank()
+
+        self.k_norm = RMSNorm(
+            self.total_num_kv_heads * self.head_dim,
+            eps=self.config.rms_norm_eps,
+        )
+        self.q_norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+        # Rotary embeddings.
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+        )
+        self.scaling = self.head_dim**-0.5
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+        )
+
+        # Attention output projection.
+        self.o_proj = RowParallelLinear(
+            self.head_dim * self.total_num_heads,
+            self.hidden_size,
+            bias=config.attention_bias,
+        )
+
+    def _apply_qk_norm(
+        self, q: torch.Tensor, k: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if self.tp_size > 1:
+            q = tensor_model_parallel_all_gather(q.contiguous())
+            k = tensor_model_parallel_all_gather(k.contiguous())
+        q = self.q_norm.forward_native(q)
+        k = self.k_norm.forward_native(k)
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+        return q, k
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k = self._apply_qk_norm(q, k)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class Olmo2MLP(nn.Module):
+    """
+    This is the MLP block where the output is computed as
+    ``MLP(x)`` in ``LN(MLP(x + LN(Attention(x))))``
+    (plus another skip connection).
+    """
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+
+        # Feed-forward input projection.
+        self.gate_up_proj = MergedColumnParallelLinear(
+            self.hidden_size,
+            [self.intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+        )
+
+        # Activation function.
+        self.act_fn = SiluAndMul()
+
+        # Feed-forward output projection.
+        self.down_proj = RowParallelLinear(
+            self.intermediate_size,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class Olmo2DecoderLayer(nn.Module):
+    """
+    This is a typical transformer block where the output is
+    computed as ``MLP(LN(x + Attention(LN(x))))``
+    (plus another skip connection).
+    """
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        # Attention block.
+        self.self_attn = Olmo2Attention(config, layer_id, quant_config)
+
+        # MLP block.
+        self.mlp = Olmo2MLP(config, quant_config)
+
+        # RMSNorm
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+        self.post_feedforward_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        # Attention block.
+        residual = hidden_states
+        hidden_states = self.self_attn(positions, hidden_states, forward_batch)
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = hidden_states + residual
+
+        # MLP block.
+        residual = hidden_states
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.post_feedforward_layernorm(hidden_states)
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+
+class Olmo2Model(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size, config.hidden_size
+        )
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: Olmo2DecoderLayer(
+                layer_id=idx,
+                config=config,
+                quant_config=quant_config,
+            ),
+        )
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        """
+        :param input_ids: A tensor of shape `(batch_size, seq_len)`.
+        """
+        # Get embeddings of input.
+        # shape: (batch_size, seq_len, d_model)
+
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+
+        # Apply blocks one-by-one.
+        for layer_id, decoder_layer in enumerate(self.layers):
+            # shape: (batch_size, seq_len, d_model)
+            hidden_states = decoder_layer(
+                positions,
+                hidden_states,
+                forward_batch,
+            )
+
+        # Apply final layer norm.
+        # shape: (batch_size, seq_len or 1, d_model)
+        hidden_states = self.norm(hidden_states)
+        return hidden_states
+
+
+class Olmo2ForCausalLM(nn.Module):
+    """
+    Extremely barebones HF model wrapper.
+    """
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.model = Olmo2Model(config, quant_config)
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.unpadded_vocab_size = config.vocab_size
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                quant_config=quant_config,
+            )
+        self.logits_processor = LogitsProcessor(config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            input_embeds=input_embeds,
+        )
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            # With tie_word_embeddings, we can skip lm_head.weight
+            # The weight might appear unnecessarily in the files if the model is
+            # processed with quantization, LoRA, fine-tuning, etc.
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = Olmo2ForCausalLM
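For readers skimming the new model above: unlike the pre-norm Llama-style block, `Olmo2DecoderLayer.forward` normalizes each sublayer's output before adding the residual. Below is a minimal, self-contained sketch of that ordering only; the module names and the use of `nn.LayerNorm` (in place of the package's RMSNorm) are illustrative and not part of sglang.

```python
import torch
from torch import nn


class NormAfterBlock(nn.Module):
    """Toy block mirroring the residual/norm ordering of Olmo2DecoderLayer:
    x = x + Norm(Attn(x)); x = x + Norm(MLP(x))."""

    def __init__(self, d_model: int = 64):
        super().__init__()
        # Stand-ins for the attention and MLP sublayers.
        self.attn = nn.Linear(d_model, d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.SiLU(), nn.Linear(4 * d_model, d_model)
        )
        # The real model uses RMSNorm; LayerNorm keeps this sketch dependency-light.
        self.post_attention_layernorm = nn.LayerNorm(d_model)
        self.post_feedforward_layernorm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Attention sublayer: normalize the output, then add the residual.
        x = x + self.post_attention_layernorm(self.attn(x))
        # MLP sublayer: same "norm-after" ordering.
        x = x + self.post_feedforward_layernorm(self.mlp(x))
        return x


print(NormAfterBlock()(torch.randn(2, 8, 64)).shape)  # torch.Size([2, 8, 64])
```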
sglang/srt/models/olmoe.py
CHANGED
@@ -34,8 +34,6 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.utils import print_warning_once

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.fused_moe_triton import FusedMoE
@@ -48,7 +46,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import make_layers, print_warning_once


 class OlmoeMoE(nn.Module):
@@ -300,7 +299,6 @@ class OlmoeForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -321,7 +319,7 @@ class OlmoeForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/phi3_small.py
CHANGED
@@ -7,8 +7,6 @@ from transformers import Phi3Config
 from transformers.configuration_utils import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.utils import make_layers

 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -27,6 +25,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import make_layers


 @torch.jit.script
@@ -235,7 +235,6 @@ class Phi3SmallDecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -286,7 +285,6 @@ class Phi3SmallModel(nn.Module):
         super().__init__()

         self.config = config
-        cache_config = None
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size, config.hidden_size
         )
@@ -294,7 +292,7 @@ class Phi3SmallModel(nn.Module):
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda prefix: Phi3SmallDecoderLayer(
-                config, int(prefix.split(".")[-1]),
+                config, int(prefix.split(".")[-1]), quant_config
             ),
             prefix=f"{prefix}.layers",
         )
@@ -339,7 +337,6 @@ class Phi3SmallForCausalLM(nn.Module):
         self,
         config: Phi3Config,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ):

         super().__init__()
@@ -397,10 +394,13 @@ class Phi3SmallForCausalLM(nn.Module):

     def compute_logits(
         self,
+        input_ids: torch.LongTensor,
         hidden_states: torch.Tensor,
         sampling_metadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(
+        logits = self.logits_processor(
+            input_ids, self.lm_head, hidden_states, sampling_metadata
+        )
         if self.dummy_token_indices is not None and logits is not None:
             logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
         return logits
@@ -422,7 +422,7 @@ class Phi3SmallForCausalLM(nn.Module):

         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head
+                input_ids, hidden_states, self.lm_head, forward_batch
             )

         else:
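One detail worth noting in the `make_layers` call above: the factory lambda receives the module prefix and recovers the layer index from its last dot-separated component. A two-line illustration of that `int(prefix.split(".")[-1])` idiom, independent of the package (the example prefix string assumes the vllm-style `model.layers.<idx>` naming):

```python
prefix = "model.layers.7"  # example prefix; the exact string comes from make_layers
layer_id = int(prefix.split(".")[-1])
assert layer_id == 7
```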
sglang/srt/models/qwen.py
CHANGED
@@ -22,7 +22,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -39,6 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class QWenMLP(nn.Module):
@@ -242,7 +242,6 @@ class QWenLMHeadModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ):
         super().__init__()
         self.config = config
@@ -260,7 +259,7 @@ class QWenLMHeadModel(nn.Module):
     ):
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/qwen2.py
CHANGED
@@ -22,7 +22,6 @@ import torch
 from torch import nn
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -40,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import make_layers

 Qwen2Config = None
@@ -230,6 +230,7 @@ class Qwen2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            quant_config=quant_config,
         )
         self.layers = make_layers(
             config.num_hidden_layers,
@@ -270,13 +271,17 @@ class Qwen2ForCausalLM(nn.Module):
         self,
         config: Qwen2Config,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen2Model(config, quant_config=quant_config)
-
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

@@ -292,7 +297,7 @@ class Qwen2ForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head
+                input_ids, hidden_states, self.lm_head, forward_batch
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -306,6 +311,7 @@ class Qwen2ForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
@@ -335,11 +341,6 @@ class Qwen2ForCausalLM(nn.Module):
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
-                if (
-                    self.config.tie_word_embeddings
-                    and name == "model.embed_tokens.weight"
-                ):
-                    weight_loader(params_dict["lm_head.weight"], loaded_weight)


 EntryClass = Qwen2ForCausalLM
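The `Qwen2ForCausalLM` hunks above move `tie_word_embeddings` handling from `load_weights` (which used to copy the embedding weight into `lm_head.weight`) to construction time, where the LM head simply reuses the embedding module. A stripped-down sketch of the two layouts using plain `torch.nn` modules; `TinyLM` and its fields are illustrative names, not the package's classes.

```python
import torch
from torch import nn


class TinyLM(nn.Module):
    """Toy LM head showing construction-time weight tying."""

    def __init__(self, vocab_size: int = 100, hidden: int = 32, tie_word_embeddings: bool = True):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)
        if tie_word_embeddings:
            # Tied: the head *is* the embedding module, so there is a single
            # parameter tensor and nothing to copy during weight loading.
            self.lm_head = self.embed_tokens
        else:
            # Untied: a separate output projection with its own weight.
            self.lm_head = nn.Linear(hidden, vocab_size, bias=False)

    def logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Both layouts expose a (vocab_size, hidden) weight matrix.
        return hidden_states @ self.lm_head.weight.t()


model = TinyLM(tie_word_embeddings=True)
# The tied head shares storage with the embedding table.
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
print(model.logits(torch.randn(2, 32)).shape)  # torch.Size([2, 100])
```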
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -27,7 +27,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.fused_moe_triton import FusedMoE
@@ -48,6 +47,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class Qwen2MoeMLP(nn.Module):
@@ -158,7 +158,6 @@ class Qwen2MoeAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -234,7 +233,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -250,7 +248,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
             quant_config=quant_config,
         )

@@ -304,7 +301,6 @@ class Qwen2MoeModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -317,9 +313,7 @@ class Qwen2MoeModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                Qwen2MoeDecoderLayer(
-                    config, layer_id, cache_config, quant_config=quant_config
-                )
+                Qwen2MoeDecoderLayer(config, layer_id, quant_config=quant_config)
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
@@ -353,14 +347,13 @@ class Qwen2MoeForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
         self.torchao_config = global_server_args_dict["torchao_config"]
-        self.model = Qwen2MoeModel(config,
+        self.model = Qwen2MoeModel(config, quant_config)
         self.lm_head = ParallelLMHead(
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
@@ -376,7 +369,7 @@ class Qwen2MoeForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
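Finally, every `load_weights` method touched in this release, including the new `olmo2.py` above, relies on the same Python `for ... else` dispatch: try each entry of `stacked_params_mapping`, and only if no fused-parameter rename matched does the `else` branch fall back to the default loader. A self-contained sketch of that control flow, with print statements standing in for the real weight loaders (nothing here is sglang's API):

```python
from typing import Iterable, Tuple

# Fused-parameter renames, as listed in the diffs above: (param_name, shard_name, shard_id).
STACKED_PARAMS_MAPPING = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]


def load_weights(weights: Iterable[Tuple[str, object]]) -> None:
    for name, _loaded_weight in weights:
        for param_name, weight_name, shard_id in STACKED_PARAMS_MAPPING:
            if weight_name not in name:
                continue
            fused_name = name.replace(weight_name, param_name)
            print(f"shard-load {name} -> {fused_name}[{shard_id}]")
            break  # a mapping matched, so the else branch below is skipped
        else:
            # Runs only if the inner loop finished without break:
            # this weight is not part of a fused parameter.
            print(f"default-load {name}")


load_weights(
    [
        ("model.layers.0.self_attn.q_proj.weight", None),
        ("model.norm.weight", None),
    ]
)
```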