sglang 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +3 -1
- sglang/api.py +7 -7
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +158 -11
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/bench_latency.py +299 -0
- sglang/global_config.py +12 -2
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +114 -67
- sglang/lang/ir.py +28 -3
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +13 -6
- sglang/srt/constrained/fsm_cache.py +8 -2
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +3 -1
- sglang/srt/hf_transformers_utils.py +130 -1
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +582 -0
- sglang/srt/layers/logits_processor.py +65 -32
- sglang/srt/layers/radix_attention.py +41 -7
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/{router → controller}/infer_batch.py +242 -100
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/{router/manager.py → controller/manager_single.py} +34 -14
- sglang/srt/managers/{router → controller}/model_runner.py +262 -158
- sglang/srt/managers/{router → controller}/radix_cache.py +11 -1
- sglang/srt/managers/{router/scheduler.py → controller/schedule_heuristic.py} +9 -7
- sglang/srt/managers/{router/model_rpc.py → controller/tp_worker.py} +298 -267
- sglang/srt/managers/detokenizer_manager.py +42 -46
- sglang/srt/managers/io_struct.py +22 -12
- sglang/srt/managers/tokenizer_manager.py +151 -87
- sglang/srt/model_config.py +83 -5
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +10 -13
- sglang/srt/models/dbrx.py +9 -15
- sglang/srt/models/gemma.py +12 -15
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +26 -15
- sglang/srt/models/llama_classification.py +104 -0
- sglang/srt/models/llava.py +86 -19
- sglang/srt/models/llavavid.py +11 -20
- sglang/srt/models/mixtral.py +282 -103
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +9 -13
- sglang/srt/models/qwen2.py +11 -13
- sglang/srt/models/stablelm.py +9 -15
- sglang/srt/models/yivl.py +17 -22
- sglang/srt/openai_api_adapter.py +150 -95
- sglang/srt/openai_protocol.py +11 -2
- sglang/srt/server.py +124 -48
- sglang/srt/server_args.py +128 -48
- sglang/srt/utils.py +234 -67
- sglang/test/test_programs.py +65 -3
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +23 -4
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/METADATA +40 -27
- sglang-0.1.18.dist-info/RECORD +78 -0
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -417
- sglang-0.1.16.dist-info/RECORD +0 -72
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,372 @@
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral_quant.py#L1
+"""Inference-only Mixtral model."""
+from typing import Iterable, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers import MixtralConfig
+from vllm.config import CacheConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.model_runner import InputMetadata
+
+
+class MixtralMLP(nn.Module):
+    def __init__(
+        self,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.num_experts = num_experts
+        self.ffn_dim = intermediate_size
+        self.hidden_dim = hidden_size
+
+        self.w1 = ReplicatedLinear(
+            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
+        )
+        self.w2 = ReplicatedLinear(
+            self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
+        )
+        self.w3 = ReplicatedLinear(
+            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
+        )
+
+        # TODO: Use vllm's SiluAndMul
+        self.act_fn = nn.SiLU()
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        w1_out, _ = self.w1(hidden_states)
+        w1_out = self.act_fn(w1_out)
+        w3_out, _ = self.w3(hidden_states)
+        current_hidden_states = w1_out * w3_out
+        current_hidden_states, _ = self.w2(current_hidden_states)
+        return current_hidden_states
+
+
+class MixtralMoE(nn.Module):
+    def __init__(
+        self,
+        config: MixtralConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.rank = get_tensor_model_parallel_rank()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_total_experts = config.num_local_experts
+        self.top_k = config.num_experts_per_tok
+        if self.tp_size > self.num_total_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {self.num_total_experts}."
+            )
+        # Split experts equally between ranks
+        self.expert_indicies = np.array_split(
+            range(self.num_total_experts), self.tp_size
+        )[self.rank].tolist()
+        if not self.expert_indicies:
+            raise ValueError(f"Rank {self.rank} has no experts assigned to it.")
+
+        self.experts = nn.ModuleList(
+            [
+                (
+                    MixtralMLP(
+                        self.num_total_experts,
+                        config.hidden_size,
+                        config.intermediate_size,
+                        quant_config=quant_config,
+                    )
+                    if idx in self.expert_indicies
+                    else None
+                )
+                for idx in range(self.num_total_experts)
+            ]
+        )
+        self.gate = ReplicatedLinear(
+            config.hidden_size, self.num_total_experts, bias=False, quant_config=None
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        router_logits, _ = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, self.top_k, dim=-1
+        )
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+
+        final_hidden_states = None
+        for expert_idx in self.expert_indicies:
+            expert_layer = self.experts[expert_idx]
+            expert_mask = selected_experts == expert_idx
+            expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
+
+            current_hidden_states = expert_layer(hidden_states).mul_(expert_weights)
+            if final_hidden_states is None:
+                final_hidden_states = current_hidden_states
+            else:
+                final_hidden_states.add_(current_hidden_states)
+
+        return tensor_model_parallel_all_reduce(final_hidden_states)
+
+
+class MixtralAttention(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        layer_id: int = 0,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
+        quant_config: Optional[QuantizationConfig] = None,
+        sliding_window: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.sliding_window = sliding_window
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position,
+            base=int(self.rope_theta),
+            is_neox_style=True,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, input_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class MixtralDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: MixtralConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        # Requires transformers > 4.32.0
+        rope_theta = getattr(config, "rope_theta", 10000)
+        self.self_attn = MixtralAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            max_position=config.max_position_embeddings,
+            num_kv_heads=config.num_key_value_heads,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            sliding_window=config.sliding_window,
+            quant_config=quant_config,
+        )
+        self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            input_metadata=input_metadata,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.block_sparse_moe(hidden_states)
+        return hidden_states, residual
+
+
+class MixtralModel(nn.Module):
+    def __init__(
+        self,
+        config: MixtralConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.layers = nn.ModuleList(
+            [
+                MixtralDecoderLayer(config, i, quant_config=quant_config)
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions, hidden_states, input_metadata, residual
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class QuantMixtralForCausalLM(nn.Module):
+    def __init__(
+        self,
+        config: MixtralConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = MixtralModel(config, quant_config=quant_config)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.logits_processor = LogitsProcessor(config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, input_metadata
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Skip experts that are not assigned to this worker.
+                if "block_sparse_moe.experts." in name and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = QuantMixtralForCausalLM
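
The MixtralMoE forward pass above scores every token against all experts, keeps the top-k routing weights, renormalizes them, and then lets each tensor-parallel rank apply only the experts it owns before tensor_model_parallel_all_reduce combines the partial sums. Below is a minimal standalone sketch of just that routing math, using plain nn.Linear experts in place of ReplicatedLinear and skipping tensor parallelism; the sizes and expert count are illustrative, not taken from the package.

# Standalone sketch: reproduces the routing computation from MixtralMoE.forward
# with dense nn.Linear experts and no tensor parallelism.
import torch
import torch.nn.functional as F
from torch import nn

num_experts, top_k, hidden = 8, 2, 16            # illustrative sizes
experts = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(num_experts)])
gate = nn.Linear(hidden, num_experts, bias=False)

hidden_states = torch.randn(4, hidden)           # 4 tokens

router_logits = gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)  # renormalize over the chosen experts

final_hidden_states = None
for expert_idx in range(num_experts):            # a TP rank would only loop over its own expert_indicies
    expert_mask = selected_experts == expert_idx
    expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
    current = experts[expert_idx](hidden_states).mul_(expert_weights)
    final_hidden_states = current if final_hidden_states is None else final_hidden_states.add_(current)

print(final_hidden_states.shape)                 # torch.Size([4, 16])
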
sglang/srt/models/qwen.py
CHANGED
@@ -1,8 +1,11 @@
-
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
+from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig
+from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -17,11 +20,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
+from sglang.srt.managers.controller.model_runner import InputMetadata
 
 
 class QWenMLP(nn.Module):
@@ -225,6 +228,7 @@ class QWenLMHeadModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -245,22 +249,14 @@ class QWenLMHeadModel(nn.Module):
         )
         return next_tokens
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "w2", 0),
             ("gate_up_proj", "w1", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
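
Across these model files the load_weights signature changes from taking a checkpoint location (model_name_or_path, cache_dir, load_format, revision) and iterating with sglang's own hf_model_weights_iterator to accepting an already-open stream of (name, tensor) pairs, matching vLLM's model-loader convention. A hedged sketch of how a caller might drive the new interface; iter_checkpoint and the file name are illustrative and not part of the package.

from typing import Iterable, Tuple

import torch

def iter_checkpoint(state_dict: dict) -> Iterable[Tuple[str, torch.Tensor]]:
    # In practice vLLM's model loader produces this stream; here it is faked
    # from an in-memory state dict for illustration only.
    for name, tensor in state_dict.items():
        yield name, tensor

# model = QWenLMHeadModel(config, quant_config=None, cache_config=None)
# model.load_weights(iter_checkpoint(torch.load("qwen.pt")))  # path is hypothetical
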
sglang/srt/models/qwen2.py
CHANGED
@@ -1,10 +1,11 @@
 # Adapted from llama2.py
 # Modify details for the adaptation of Qwen2 model.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
 from torch import nn
+from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -19,11 +20,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
+from sglang.srt.managers.controller.model_runner import InputMetadata
 
 Qwen2Config = None
 
@@ -251,6 +252,7 @@ class Qwen2ForCausalLM(nn.Module):
         self,
         config: Qwen2Config,
         quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -271,13 +273,7 @@ class Qwen2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -287,9 +283,7 @@ class Qwen2ForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -303,6 +297,8 @@ class Qwen2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if name.startswith("model.vision_tower") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -311,6 +307,8 @@ class Qwen2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if name.startswith("model.vision_tower") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
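
The stacked_params_mapping seen in these load_weights bodies exists because checkpoints store q/k/v (and gate/up) projections as separate tensors, while the model fuses them into qkv_proj and gate_up_proj; the loader rewrites the tensor name and hands a shard id to the fused parameter's weight_loader. A small sketch of just the renaming step, assuming the usual mapping (the gate_proj entry is inferred from the pattern and is not shown verbatim in the hunks above):

stacked_params_mapping = [
    # (fused param_name, checkpoint shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),  # assumed entry, by analogy with up_proj
    ("gate_up_proj", "up_proj", 1),
]

def resolve(name: str):
    # Map a checkpoint tensor name to (model parameter name, shard_id);
    # shard_id is None for tensors that are not part of a fused parameter.
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None

print(resolve("model.layers.0.self_attn.k_proj.weight"))
# -> ('model.layers.0.self_attn.qkv_proj.weight', 'k')
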
sglang/srt/models/stablelm.py
CHANGED
@@ -1,12 +1,13 @@
-#
-# https://github.com/vllm-project/vllm/blob/
+# Adapted from:
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/stablelm.py#L1
 """Inference-only StableLM-2 (https://huggingface.co/stabilityai/stablelm-2-1_6b)
 model compatible with HuggingFace weights."""
-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig
+from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (
@@ -20,11 +21,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
+from sglang.srt.managers.controller.model_runner import InputMetadata
 
 
 class StablelmMLP(nn.Module):
@@ -225,6 +226,7 @@ class StableLmForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -245,13 +247,7 @@ class StableLmForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -261,9 +257,7 @@ class StableLmForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
sglang/srt/models/yivl.py
CHANGED
@@ -1,40 +1,38 @@
 """Inference-only Yi-VL model."""
 
-import
-from typing import List, Optional
+from typing import Iterable, Optional, Tuple
 
 import torch
 import torch.nn as nn
 from transformers import CLIPVisionModel, LlavaConfig
+from vllm.config import CacheConfig
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.models.llava import (
     LlavaLlamaForCausalLM,
-    clip_vision_embed_forward,
     monkey_path_clip_vision_embed_forward,
 )
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
 
 
 class YiVLForCausalLM(LlavaLlamaForCausalLM):
-    def __init__(
-        self
-
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__(config, quant_config, cache_config)
 
         self.multi_modal_projector = YiVLMultiModalProjector(self.config)
         self.vision_tower_subfolder = self.config.mm_vision_tower.replace(
             "./", ""
         ) # Everything after "./"
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # We have to use the subfolder of the main model directory (e.g. 01-ai/Yi-VL-6B)
         self.vision_tower = CLIPVisionModel.from_pretrained(
-            model_name_or_path,
+            self.config._name_or_path,
             torch_dtype=torch.float16,
             subfolder=self.vision_tower_subfolder,
         ).cuda()
@@ -68,9 +66,8 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
             "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
         }
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        weights = list(weights)
+        for name, loaded_weight in weights:
             if "projector" in name or "vision_tower" in name:
                 for weight_name, param_name in projector_weights.items():
                     if weight_name in name:
@@ -80,9 +77,7 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
                 weight_loader(param, loaded_weight)
 
         # load language model
-        self.language_model.load_weights(
-            model_name_or_path, cache_dir, load_format, revision
-        )
+        self.language_model.load_weights(weights)
 
         monkey_path_clip_vision_embed_forward()
 
@@ -103,7 +98,7 @@ class YiVLMultiModalProjector(nn.Module):
 
     def forward(self, image_features):
         hidden_states = self.linear_1(image_features)
-
+        hidden_states = self.ln_1(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states = self.linear_2(hidden_states)
         hidden_states = self.ln_2(hidden_states)