sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff shows the changes between two package versions that have been publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- sglang/__init__.py +59 -2
- sglang/api.py +40 -11
- sglang/backend/anthropic.py +17 -3
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +160 -12
- sglang/backend/runtime_endpoint.py +62 -27
- sglang/backend/vertexai.py +1 -0
- sglang/bench_latency.py +320 -0
- sglang/global_config.py +24 -3
- sglang/lang/chat_template.py +122 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +206 -98
- sglang/lang/ir.py +98 -34
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +32 -0
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +9 -2
- sglang/srt/constrained/jump_forward.py +113 -24
- sglang/srt/conversation.py +4 -2
- sglang/srt/flush_cache.py +18 -0
- sglang/srt/hf_transformers_utils.py +144 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +20 -1
- sglang/srt/layers/fused_moe.py +596 -0
- sglang/srt/layers/logits_processor.py +190 -61
- sglang/srt/layers/radix_attention.py +62 -53
- sglang/srt/layers/token_attention.py +21 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/controller/infer_batch.py +908 -0
- sglang/srt/managers/controller/manager_multi.py +195 -0
- sglang/srt/managers/controller/manager_single.py +177 -0
- sglang/srt/managers/controller/model_runner.py +359 -0
- sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
- sglang/srt/managers/controller/schedule_heuristic.py +65 -0
- sglang/srt/managers/controller/tp_worker.py +813 -0
- sglang/srt/managers/detokenizer_manager.py +42 -40
- sglang/srt/managers/io_struct.py +44 -10
- sglang/srt/managers/tokenizer_manager.py +224 -82
- sglang/srt/memory_pool.py +52 -59
- sglang/srt/model_config.py +97 -2
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +369 -0
- sglang/srt/models/dbrx.py +406 -0
- sglang/srt/models/gemma.py +34 -38
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +47 -37
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +92 -27
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/minicpm.py +366 -0
- sglang/srt/models/mixtral.py +302 -127
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +40 -35
- sglang/srt/models/qwen2.py +33 -36
- sglang/srt/models/qwen2_moe.py +473 -0
- sglang/srt/models/stablelm.py +33 -39
- sglang/srt/models/yivl.py +19 -26
- sglang/srt/openai_api_adapter.py +411 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +197 -481
- sglang/srt/server_args.py +190 -74
- sglang/srt/utils.py +460 -95
- sglang/test/test_programs.py +73 -10
- sglang/test/test_utils.py +226 -7
- sglang/utils.py +97 -27
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
- sglang-0.1.21.dist-info/RECORD +82 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/managers/router/infer_batch.py +0 -503
- sglang/srt/managers/router/manager.py +0 -79
- sglang/srt/managers/router/model_rpc.py +0 -686
- sglang/srt/managers/router/model_runner.py +0 -514
- sglang/srt/managers/router/scheduler.py +0 -70
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,372 @@
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral_quant.py#L1
+"""Inference-only Mixtral model."""
+from typing import Iterable, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers import MixtralConfig
+from vllm.config import CacheConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.model_runner import InputMetadata
+
+
+class MixtralMLP(nn.Module):
+    def __init__(
+        self,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.num_experts = num_experts
+        self.ffn_dim = intermediate_size
+        self.hidden_dim = hidden_size
+
+        self.w1 = ReplicatedLinear(
+            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
+        )
+        self.w2 = ReplicatedLinear(
+            self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
+        )
+        self.w3 = ReplicatedLinear(
+            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
+        )
+
+        # TODO: Use vllm's SiluAndMul
+        self.act_fn = nn.SiLU()
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        w1_out, _ = self.w1(hidden_states)
+        w1_out = self.act_fn(w1_out)
+        w3_out, _ = self.w3(hidden_states)
+        current_hidden_states = w1_out * w3_out
+        current_hidden_states, _ = self.w2(current_hidden_states)
+        return current_hidden_states
+
+
+class MixtralMoE(nn.Module):
+    def __init__(
+        self,
+        config: MixtralConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.rank = get_tensor_model_parallel_rank()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_total_experts = config.num_local_experts
+        self.top_k = config.num_experts_per_tok
+        if self.tp_size > self.num_total_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {self.num_total_experts}."
+            )
+        # Split experts equally between ranks
+        self.expert_indicies = np.array_split(
+            range(self.num_total_experts), self.tp_size
+        )[self.rank].tolist()
+        if not self.expert_indicies:
+            raise ValueError(f"Rank {self.rank} has no experts assigned to it.")
+
+        self.experts = nn.ModuleList(
+            [
+                (
+                    MixtralMLP(
+                        self.num_total_experts,
+                        config.hidden_size,
+                        config.intermediate_size,
+                        quant_config=quant_config,
+                    )
+                    if idx in self.expert_indicies
+                    else None
+                )
+                for idx in range(self.num_total_experts)
+            ]
+        )
+        self.gate = ReplicatedLinear(
+            config.hidden_size, self.num_total_experts, bias=False, quant_config=None
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        router_logits, _ = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, self.top_k, dim=-1
+        )
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+
+        final_hidden_states = None
+        for expert_idx in self.expert_indicies:
+            expert_layer = self.experts[expert_idx]
+            expert_mask = selected_experts == expert_idx
+            expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
+
+            current_hidden_states = expert_layer(hidden_states).mul_(expert_weights)
+            if final_hidden_states is None:
+                final_hidden_states = current_hidden_states
+            else:
+                final_hidden_states.add_(current_hidden_states)
+
+        return tensor_model_parallel_all_reduce(final_hidden_states)
+
+
+class MixtralAttention(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        layer_id: int = 0,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
+        quant_config: Optional[QuantizationConfig] = None,
+        sliding_window: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.sliding_window = sliding_window
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position,
+            base=int(self.rope_theta),
+            is_neox_style=True,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, input_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class MixtralDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: MixtralConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        # Requires transformers > 4.32.0
+        rope_theta = getattr(config, "rope_theta", 10000)
+        self.self_attn = MixtralAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            max_position=config.max_position_embeddings,
+            num_kv_heads=config.num_key_value_heads,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            sliding_window=config.sliding_window,
+            quant_config=quant_config,
+        )
+        self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            input_metadata=input_metadata,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.block_sparse_moe(hidden_states)
+        return hidden_states, residual
+
+
+class MixtralModel(nn.Module):
+    def __init__(
+        self,
+        config: MixtralConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.layers = nn.ModuleList(
+            [
+                MixtralDecoderLayer(config, i, quant_config=quant_config)
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions, hidden_states, input_metadata, residual
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class QuantMixtralForCausalLM(nn.Module):
+    def __init__(
+        self,
+        config: MixtralConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = MixtralModel(config, quant_config=quant_config)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.logits_processor = LogitsProcessor(config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, input_metadata
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Skip experts that are not assigned to this worker.
+                if "block_sparse_moe.experts." in name and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = QuantMixtralForCausalLM
sglang/srt/models/qwen.py
CHANGED
@@ -1,31 +1,30 @@
-
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
+from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
 from torch import nn
 from transformers import PretrainedConfig
+from vllm.config import CacheConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
-    LinearMethodBase,
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.
-
-
-from
-
-    hf_model_weights_iterator,
-)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.model_runner import InputMetadata
 
 
 class QWenMLP(nn.Module):
@@ -34,7 +33,7 @@ class QWenMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str = "silu",
-
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -42,14 +41,14 @@ class QWenMLP(nn.Module):
             2 * [intermediate_size],
             bias=False,
             gather_output=False,
-
+            quant_config=quant_config,
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             input_is_parallel=True,
-
+            quant_config=quant_config,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -74,7 +73,7 @@ class QWenAttention(nn.Module):
         layer_id: int = 0,
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -90,14 +89,14 @@ class QWenAttention(nn.Module):
             self.head_dim,
             self.total_num_heads,
             bias=True,
-
+            quant_config=quant_config,
         )
         self.c_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             input_is_parallel=True,
-
+            quant_config=quant_config,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -130,7 +129,12 @@ class QWenAttention(nn.Module):
 
 
 class QWenBlock(nn.Module):
-    def __init__(
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         super().__init__()
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
@@ -143,7 +147,7 @@ class QWenBlock(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             layer_id=layer_id,
-
+            quant_config=quant_config,
         )
 
         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -151,7 +155,7 @@ class QWenBlock(nn.Module):
         self.mlp = QWenMLP(
             config.hidden_size,
             config.intermediate_size // 2,
-
+            quant_config=quant_config,
         )
 
     def forward(
@@ -179,7 +183,11 @@ class QWenBlock(nn.Module):
 
 
 class QWenModel(nn.Module):
-    def __init__(
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size
@@ -191,7 +199,7 @@ class QWenModel(nn.Module):
         )
         self.h = nn.ModuleList(
             [
-                QWenBlock(config, i,
+                QWenBlock(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -216,10 +224,15 @@ class QWenModel(nn.Module):
 
 
 class QWenLMHeadModel(nn.Module):
-    def __init__(
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ):
         super().__init__()
         self.config = config
-        self.transformer = QWenModel(config,
+        self.transformer = QWenModel(config, quant_config=quant_config)
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
@@ -236,22 +249,14 @@ class QWenLMHeadModel(nn.Module):
         )
         return next_tokens
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "w2", 0),
             ("gate_up_proj", "w1", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping: