sglang 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +5 -1
- sglang/api.py +8 -3
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +148 -12
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/global_config.py +11 -1
- sglang/lang/chat_template.py +9 -2
- sglang/lang/interpreter.py +161 -81
- sglang/lang/ir.py +29 -11
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -2
- sglang/launch_server_llavavid.py +31 -0
- sglang/srt/constrained/fsm_cache.py +3 -0
- sglang/srt/flush_cache.py +16 -0
- sglang/srt/hf_transformers_utils.py +83 -2
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +485 -0
- sglang/srt/layers/logits_processor.py +12 -7
- sglang/srt/layers/radix_attention.py +10 -3
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +110 -0
- sglang/srt/managers/controller/infer_batch.py +619 -0
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/controller/manager_single.py +97 -0
- sglang/srt/managers/controller/model_runner.py +462 -0
- sglang/srt/managers/controller/radix_cache.py +267 -0
- sglang/srt/managers/controller/schedule_heuristic.py +59 -0
- sglang/srt/managers/controller/tp_worker.py +791 -0
- sglang/srt/managers/detokenizer_manager.py +45 -45
- sglang/srt/managers/io_struct.py +26 -10
- sglang/srt/managers/router/infer_batch.py +130 -74
- sglang/srt/managers/router/manager.py +7 -9
- sglang/srt/managers/router/model_rpc.py +224 -135
- sglang/srt/managers/router/model_runner.py +94 -107
- sglang/srt/managers/router/radix_cache.py +54 -18
- sglang/srt/managers/router/scheduler.py +23 -34
- sglang/srt/managers/tokenizer_manager.py +183 -88
- sglang/srt/model_config.py +5 -2
- sglang/srt/models/commandr.py +15 -22
- sglang/srt/models/dbrx.py +22 -29
- sglang/srt/models/gemma.py +14 -24
- sglang/srt/models/grok.py +671 -0
- sglang/srt/models/llama2.py +24 -23
- sglang/srt/models/llava.py +85 -25
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/mixtral.py +254 -130
- sglang/srt/models/mixtral_quant.py +373 -0
- sglang/srt/models/qwen.py +28 -25
- sglang/srt/models/qwen2.py +17 -22
- sglang/srt/models/stablelm.py +21 -26
- sglang/srt/models/yivl.py +17 -25
- sglang/srt/openai_api_adapter.py +140 -95
- sglang/srt/openai_protocol.py +10 -1
- sglang/srt/server.py +101 -52
- sglang/srt/server_args.py +59 -11
- sglang/srt/utils.py +242 -75
- sglang/test/test_programs.py +44 -0
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +95 -26
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/METADATA +23 -13
- sglang-0.1.17.dist-info/RECORD +81 -0
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -402
- sglang-0.1.15.dist-info/RECORD +0 -69
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
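A note for reading the model diffs below: sglang/srt/weight_utils.py is removed (-402 lines) and the model files switch to default_weight_loader from vllm.model_executor.model_loader.weight_utils, so each model's load_weights changes from taking a checkpoint path to consuming an iterator of (name, tensor) pairs supplied by the caller. The following standalone sketch is not code from the package; TinyModel and iter_state_dict are invented here purely to illustrate the shape of the new interface.

# Hypothetical sketch of the 0.1.17-style load_weights contract seen in the diffs below.
from typing import Iterable, Tuple

import torch


def iter_state_dict(state_dict: dict) -> Iterable[Tuple[str, torch.Tensor]]:
    """Adapt an in-memory state dict to the (name, tensor) stream the new API consumes."""
    for name, tensor in state_dict.items():
        yield name, tensor


class TinyModel(torch.nn.Module):
    """Stand-in model with the new-style load_weights signature."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(4, 4, bias=False)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if name in params_dict:
                # Copy the checkpoint tensor into the matching parameter.
                params_dict[name].data.copy_(loaded_weight)


model = TinyModel()
model.load_weights(iter_state_dict({"proj.weight": torch.eye(4)}))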
sglang/srt/models/mixtral_quant.py
ADDED
@@ -0,0 +1,373 @@
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral_quant.py#L1
+"""Inference-only Mixtral model."""
+from typing import Iterable, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers import MixtralConfig
+from vllm.config import CacheConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.model_runner import InputMetadata
+
+
+class MixtralMLP(nn.Module):
+    def __init__(
+        self,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.num_experts = num_experts
+        self.ffn_dim = intermediate_size
+        self.hidden_dim = hidden_size
+
+        self.w1 = ReplicatedLinear(
+            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
+        )
+        self.w2 = ReplicatedLinear(
+            self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
+        )
+        self.w3 = ReplicatedLinear(
+            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
+        )
+
+        # TODO: Use vllm's SiluAndMul
+        self.act_fn = nn.SiLU()
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        w1_out, _ = self.w1(hidden_states)
+        w1_out = self.act_fn(w1_out)
+        w3_out, _ = self.w3(hidden_states)
+        current_hidden_states = w1_out * w3_out
+        current_hidden_states, _ = self.w2(current_hidden_states)
+        return current_hidden_states
+
+
+class MixtralMoE(nn.Module):
+    def __init__(
+        self,
+        config: MixtralConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.rank = get_tensor_model_parallel_rank()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_total_experts = config.num_local_experts
+        self.top_k = config.num_experts_per_tok
+        if self.tp_size > self.num_total_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {self.num_total_experts}."
+            )
+        # Split experts equally between ranks
+        self.expert_indicies = np.array_split(
+            range(self.num_total_experts), self.tp_size
+        )[self.rank].tolist()
+        if not self.expert_indicies:
+            raise ValueError(f"Rank {self.rank} has no experts assigned to it.")
+
+        self.experts = nn.ModuleList(
+            [
+                (
+                    MixtralMLP(
+                        self.num_total_experts,
+                        config.hidden_size,
+                        config.intermediate_size,
+                        quant_config=quant_config,
+                    )
+                    if idx in self.expert_indicies
+                    else None
+                )
+                for idx in range(self.num_total_experts)
+            ]
+        )
+        self.gate = ReplicatedLinear(
+            config.hidden_size, self.num_total_experts, bias=False, quant_config=None
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        router_logits, _ = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, self.top_k, dim=-1
+        )
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+
+        final_hidden_states = None
+        for expert_idx in self.expert_indicies:
+            expert_layer = self.experts[expert_idx]
+            expert_mask = selected_experts == expert_idx
+            expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
+
+            current_hidden_states = expert_layer(hidden_states).mul_(expert_weights)
+            if final_hidden_states is None:
+                final_hidden_states = current_hidden_states
+            else:
+                final_hidden_states.add_(current_hidden_states)
+
+        return tensor_model_parallel_all_reduce(final_hidden_states)
+
+
+class MixtralAttention(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        layer_id: int = 0,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
+        quant_config: Optional[QuantizationConfig] = None,
+        sliding_window: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.sliding_window = sliding_window
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position,
+            base=int(self.rope_theta),
+            is_neox_style=True,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, input_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class MixtralDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: MixtralConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        # Requires transformers > 4.32.0
+        rope_theta = getattr(config, "rope_theta", 10000)
+        self.self_attn = MixtralAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            max_position=config.max_position_embeddings,
+            num_kv_heads=config.num_key_value_heads,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            sliding_window=config.sliding_window,
+            quant_config=quant_config,
+        )
+        self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            input_metadata=input_metadata,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.block_sparse_moe(hidden_states)
+        return hidden_states, residual
+
+
+class MixtralModel(nn.Module):
+    def __init__(
+        self,
+        config: MixtralConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.layers = nn.ModuleList(
+            [
+                MixtralDecoderLayer(config, i, quant_config=quant_config)
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions, hidden_states, input_metadata, residual
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class QuantMixtralForCausalLM(nn.Module):
+    def __init__(
+        self,
+        config: MixtralConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = MixtralModel(config, quant_config=quant_config)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.logits_processor = LogitsProcessor(config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, input_metadata
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Skip experts that are not assigned to this worker.
+                if "block_sparse_moe.experts." in name and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = QuantMixtralForCausalLM
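The MixtralMoE.forward pass above routes each token to its top-k experts, rescales by the normalized router weights, and accumulates the per-expert outputs, with experts split across tensor-parallel ranks and a final all-reduce. A single-rank illustration of just that routing math (plain PyTorch, no vLLM layers, no quantization; the sizes below are made up for the sketch) is:

import torch
import torch.nn.functional as F

num_experts, top_k, hidden = 4, 2, 8
experts = [torch.nn.Linear(hidden, hidden, bias=False) for _ in range(num_experts)]
gate = torch.nn.Linear(hidden, num_experts, bias=False)

x = torch.randn(3, hidden)  # 3 tokens
router_logits = gate(x)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

out = torch.zeros_like(x)
for expert_idx, expert in enumerate(experts):
    # Weight of this expert for each token (zero if the token did not select it).
    expert_mask = selected_experts == expert_idx
    expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
    out += expert(x) * expert_weights.to(x.dtype)
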
sglang/srt/models/qwen.py
CHANGED
@@ -1,8 +1,12 @@
-
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
+from typing import Any, Dict, Optional, Iterable, Tuple
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig
+from vllm.config import CacheConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
@@ -10,24 +14,17 @@ from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.
+from sglang.srt.managers.controller.model_runner import InputMetadata
 
 
 class QWenMLP(nn.Module):
@@ -132,7 +129,12 @@ class QWenAttention(nn.Module):
 
 
 class QWenBlock(nn.Module):
-    def __init__(
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         super().__init__()
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
@@ -181,7 +183,11 @@ class QWenBlock(nn.Module):
 
 
 class QWenModel(nn.Module):
-    def __init__(
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size
@@ -218,7 +224,12 @@ class QWenModel(nn.Module):
 
 
 class QWenLMHeadModel(nn.Module):
-    def __init__(
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ):
         super().__init__()
         self.config = config
         self.transformer = QWenModel(config, quant_config=quant_config)
@@ -238,22 +249,14 @@ class QWenLMHeadModel(nn.Module):
         )
         return next_tokens
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "w2", 0),
             ("gate_up_proj", "w1", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -276,4 +279,4 @@ class QWenLMHeadModel(nn.Module):
                 weight_loader(param, loaded_weight)
 
 
-EntryClass = QWenLMHeadModel
+EntryClass = QWenLMHeadModel
sglang/srt/models/qwen2.py
CHANGED
@@ -1,10 +1,12 @@
 # Adapted from llama2.py
 # Modify details for the adaptation of Qwen2 model.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Iterable
 
 import torch
 from torch import nn
+from vllm.config import CacheConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
@@ -12,24 +14,17 @@ from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.
+from sglang.srt.managers.controller.model_runner import InputMetadata
 
 Qwen2Config = None
 
@@ -50,7 +45,10 @@ class Qwen2MLP(nn.Module):
             quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size,
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -254,6 +252,7 @@ class Qwen2ForCausalLM(nn.Module):
         self,
         config: Qwen2Config,
         quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -274,13 +273,7 @@ class Qwen2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -290,9 +283,7 @@ class Qwen2ForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -306,6 +297,8 @@ class Qwen2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if name.startswith("model.vision_tower") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -314,6 +307,8 @@ class Qwen2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if name.startswith("model.vision_tower") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
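Both the qwen.py and qwen2.py load_weights methods shown above use a stacked_params_mapping table to fold per-projection checkpoint tensors (q_proj/k_proj/v_proj, gate/up) into the fused qkv_proj and gate_up_proj parameters. A standalone sketch of that key rewriting follows; the checkpoint key names are illustrative, not taken from a real checkpoint.

stacked_params_mapping = [
    # (fused param fragment, checkpoint fragment, shard id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

checkpoint_keys = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.mlp.up_proj.weight",
    "model.norm.weight",
]

for name in checkpoint_keys:
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        # Rename the checkpoint key to its fused runtime parameter.
        fused = name.replace(weight_name, param_name)
        print(f"{name} -> {fused} (shard {shard_id!r})")
        break
    else:
        print(f"{name} loaded as-is")
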
sglang/srt/models/stablelm.py
CHANGED
@@ -1,41 +1,38 @@
-#
-# https://github.com/vllm-project/vllm/blob/
+# Adapted from:
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/stablelm.py#L1
 """Inference-only StableLM-2 (https://huggingface.co/stabilityai/stablelm-2-1_6b)
 model compatible with HuggingFace weights."""
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Iterable
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig
+from vllm.config import CacheConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.
+from sglang.srt.managers.controller.model_runner import InputMetadata
 
 
 class StablelmMLP(nn.Module):
     def __init__(
-        self,
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -48,7 +45,10 @@ class StablelmMLP(nn.Module):
             quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            config.intermediate_size,
+            config.intermediate_size,
+            config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
         )
         self.act_fn = SiluAndMul()
 
@@ -181,7 +181,9 @@ class StablelmDecoderLayer(nn.Module):
 
 class StableLMEpochModel(nn.Module):
     def __init__(
-        self,
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.embed_tokens = VocabParallelEmbedding(
@@ -224,6 +226,7 @@ class StableLmForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -244,13 +247,7 @@ class StableLmForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -260,9 +257,7 @@ class StableLmForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: