sglang 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +10 -6
- sglang/bench_serving.py +33 -38
- sglang/global_config.py +0 -4
- sglang/lang/backend/runtime_endpoint.py +5 -2
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +3 -6
- sglang/launch_server_llavavid.py +7 -8
- sglang/srt/{model_config.py → configs/model_config.py} +5 -0
- sglang/srt/constrained/__init__.py +2 -0
- sglang/srt/constrained/fsm_cache.py +29 -38
- sglang/srt/constrained/jump_forward.py +0 -1
- sglang/srt/conversation.py +4 -1
- sglang/srt/hf_transformers_utils.py +1 -3
- sglang/srt/layers/attention_backend.py +480 -0
- sglang/srt/layers/flashinfer_utils.py +235 -0
- sglang/srt/layers/logits_processor.py +64 -77
- sglang/srt/layers/radix_attention.py +11 -161
- sglang/srt/layers/sampler.py +6 -25
- sglang/srt/layers/torchao_utils.py +75 -0
- sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
- sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
- sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
- sglang/srt/lora/lora.py +403 -0
- sglang/srt/lora/lora_config.py +43 -0
- sglang/srt/lora/lora_manager.py +256 -0
- sglang/srt/managers/controller_multi.py +1 -5
- sglang/srt/managers/controller_single.py +0 -5
- sglang/srt/managers/io_struct.py +16 -1
- sglang/srt/managers/policy_scheduler.py +122 -5
- sglang/srt/managers/schedule_batch.py +104 -71
- sglang/srt/managers/tokenizer_manager.py +17 -8
- sglang/srt/managers/tp_worker.py +181 -115
- sglang/srt/model_executor/cuda_graph_runner.py +58 -133
- sglang/srt/model_executor/forward_batch_info.py +35 -312
- sglang/srt/model_executor/model_runner.py +117 -131
- sglang/srt/models/baichuan.py +416 -0
- sglang/srt/models/chatglm.py +1 -5
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +1 -5
- sglang/srt/models/exaone.py +1 -5
- sglang/srt/models/gemma.py +1 -5
- sglang/srt/models/gemma2.py +1 -5
- sglang/srt/models/gpt_bigcode.py +1 -5
- sglang/srt/models/grok.py +1 -5
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/llama.py +51 -5
- sglang/srt/models/llama_classification.py +1 -20
- sglang/srt/models/llava.py +30 -5
- sglang/srt/models/llavavid.py +2 -2
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/minicpm3.py +665 -0
- sglang/srt/models/mixtral.py +6 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/qwen.py +1 -5
- sglang/srt/models/qwen2.py +1 -5
- sglang/srt/models/qwen2_moe.py +6 -5
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/xverse.py +375 -0
- sglang/srt/models/xverse_moe.py +445 -0
- sglang/srt/openai_api/adapter.py +65 -46
- sglang/srt/openai_api/protocol.py +11 -3
- sglang/srt/sampling/sampling_batch_info.py +57 -44
- sglang/srt/server.py +24 -14
- sglang/srt/server_args.py +130 -28
- sglang/srt/utils.py +12 -0
- sglang/test/few_shot_gsm8k.py +132 -0
- sglang/test/runners.py +114 -22
- sglang/test/test_programs.py +7 -5
- sglang/test/test_utils.py +85 -1
- sglang/utils.py +32 -37
- sglang/version.py +1 -1
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/METADATA +30 -18
- sglang-0.3.1.dist-info/RECORD +129 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
- sglang-0.3.0.dist-info/RECORD +0 -118
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,445 @@
|
|
1
|
+
"""
|
2
|
+
Copyright 2023-2024 SGLang Team
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
"""
|
15
|
+
|
16
|
+
"""Inference-only XVERSE MoE model."""
|
17
|
+
from typing import Any, Dict, Iterable, Optional, Tuple
|
18
|
+
|
19
|
+
import torch
|
20
|
+
from torch import nn
|
21
|
+
from transformers import PretrainedConfig
|
22
|
+
from vllm.config import CacheConfig
|
23
|
+
from vllm.distributed import (
|
24
|
+
get_tensor_model_parallel_rank,
|
25
|
+
get_tensor_model_parallel_world_size,
|
26
|
+
tensor_model_parallel_all_reduce,
|
27
|
+
)
|
28
|
+
from vllm.model_executor.layers.activation import SiluAndMul
|
29
|
+
from vllm.model_executor.layers.fused_moe import fused_moe
|
30
|
+
from vllm.model_executor.layers.layernorm import RMSNorm
|
31
|
+
from vllm.model_executor.layers.linear import (
|
32
|
+
MergedColumnParallelLinear,
|
33
|
+
QKVParallelLinear,
|
34
|
+
ReplicatedLinear,
|
35
|
+
RowParallelLinear,
|
36
|
+
)
|
37
|
+
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
38
|
+
from vllm.model_executor.layers.rotary_embedding import get_rope
|
39
|
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
40
|
+
ParallelLMHead,
|
41
|
+
VocabParallelEmbedding,
|
42
|
+
)
|
43
|
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
44
|
+
|
45
|
+
from sglang.srt.layers.logits_processor import LogitsProcessor
|
46
|
+
from sglang.srt.layers.radix_attention import RadixAttention
|
47
|
+
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
48
|
+
|
49
|
+
|
50
|
+
class XverseMLP(nn.Module):
    """Gated SiLU feed-forward block: ``down_proj(silu(gate) * up)``.

    The gate and up projections are fused into a single column-parallel
    linear layer (two output slices of ``intermediate_size`` each); the down
    projection is row-parallel. ``reduce_results`` is forwarded to the
    down projection: the MoE block below passes ``False`` and performs one
    all-reduce itself after summing expert outputs.

    Raises:
        ValueError: if ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
        )
        # Only SiLU gating is implemented (through the fused SiluAndMul op).
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )

    def forward(self, x):
        # Parallel linear layers return (output, bias); the bias slot is unused.
        fused_gate_up = self.gate_up_proj(x)[0]
        activated = self.act_fn(fused_gate_up)
        projected = self.down_proj(activated)[0]
        return projected
|
83
|
+
|
84
|
+
|
85
|
+
class XverseMoE(nn.Module):
    """Mixture-of-experts feed-forward block for XVERSE-MoE.

    Each token is routed to ``config.moe_top_k`` of ``config.num_experts``
    expert MLPs. A replicated linear router produces the routing logits and
    vLLM's ``fused_moe`` kernel evaluates the experts. When
    ``config.num_shared_experts`` is set, a dense "shared expert" MLP is added
    on top. Its intermediate size is ``intermediate_size *
    num_shared_experts``. The summed result is all-reduced once across
    tensor-parallel ranks.

    Raises:
        ValueError: if the tensor-parallel world size exceeds the number of
            experts.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.n_routed_experts = config.num_experts
        self.top_k = config.moe_top_k
        if self.tp_size > self.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {self.n_routed_experts}."
            )

        # Every routed expert is an identical, un-reduced gated MLP.
        expert_kwargs = dict(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            reduce_results=False,
        )
        self.experts = nn.ModuleList(
            XverseMLP(**expert_kwargs) for _ in range(self.n_routed_experts)
        )
        self.pack_params()

        self.router = ReplicatedLinear(
            config.hidden_size, self.n_routed_experts, bias=False, quant_config=None
        )

        if config.num_shared_experts is not None:
            self.shared_experts = XverseMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size
                * config.num_shared_experts,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
            )

    @staticmethod
    def _stack_expert_weights(params):
        """Move ``params`` into one contiguous buffer and return it stacked.

        Each parameter's ``.data`` is re-pointed at its slice of the buffer.
        The returned tensor views the same buffer with shape
        ``(len(params), *params[0].shape)``.
        """
        flat = torch._utils._flatten_dense_tensors(params)
        slices = torch._utils._unflatten_dense_tensors(flat, params)
        for piece, param in zip(slices, params):
            param.data = piece
        return flat.view(len(params), *slices[0].shape)

    def pack_params(self):
        """Expose all expert weights as two stacked tensors for ``fused_moe``.

        ``self.w1`` stacks every expert's ``gate_up_proj.weight`` and
        ``self.w2`` stacks every ``down_proj.weight``. Both alias the
        per-expert parameters' storage and are plain tensor attributes, not
        registered parameters or buffers.
        NOTE(review): this assumes weight loaders later copy into ``param.data``
        in place, so that loaded weights show up in ``w1``/``w2``. Confirm.
        """
        gate_up_weights = [expert.gate_up_proj.weight for expert in self.experts]
        down_weights = [expert.down_proj.weight for expert in self.experts]
        self.w1 = self._stack_expert_weights(gate_up_weights)
        self.w2 = self._stack_expert_weights(down_weights)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_dim = hidden_states.shape
        tokens = hidden_states.view(-1, hidden_dim)
        use_shared = self.config.num_shared_experts is not None

        # The shared expert consumes the input before fused_moe runs with
        # inplace=True. The kernel presumably overwrites its input buffer.
        shared_output = self.shared_experts(tokens) if use_shared else None

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.router(tokens)
        combined = fused_moe(
            tokens,
            self.w1,
            self.w2,
            router_logits,
            self.top_k,
            renormalize=getattr(self.config, "norm_topk_prob", False),
            inplace=True,
        )
        if use_shared:
            combined = combined + shared_output

        # Experts and shared expert were built with reduce_results=False.
        # Sum partial results across tensor-parallel ranks exactly once.
        reduced = tensor_model_parallel_all_reduce(combined)
        return reduced.view(num_tokens, hidden_dim)
|
173
|
+
|
174
|
+
|
175
|
+
class XverseAttention(nn.Module):
    """Multi-head / grouped-query self-attention with rotary embeddings.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    partitioned when there are at least as many as ranks, and replicated
    otherwise. Each rank keeps at least one KV head. Attention itself runs
    through ``RadixAttention`` with the given ``layer_id``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        world = get_tensor_model_parallel_world_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % world == 0
        self.num_heads = self.total_num_heads // world

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < world:
            # Fewer KV heads than ranks: each head is replicated on a group of
            # ranks, so the rank count must be a multiple of the head count.
            assert world % self.total_num_kv_heads == 0
        else:
            # Enough KV heads to shard them evenly across ranks.
            assert self.total_num_kv_heads % world == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // world)

        head_dim = hidden_size // self.total_num_heads
        self.head_dim = head_dim
        self.q_size = self.num_heads * head_dim
        self.kv_size = self.num_kv_heads * head_dim
        self.scaling = head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            head_dim,
            rotary_dim=head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = RadixAttention(
            self.num_heads,
            head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        # Per-rank widths of the fused q | k | v projection output.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        fused_qkv = self.qkv_proj(hidden_states)[0]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, input_metadata)
        return self.o_proj(context)[0]
|
255
|
+
|
256
|
+
|
257
|
+
class XverseDecoderLayer(nn.Module):
    """One pre-norm transformer layer: RMSNorm → attention → RMSNorm → MLP.

    The feed-forward part is an ``XverseMoE`` when ``config.num_experts`` is
    set, and a dense ``XverseMLP`` otherwise. The residual stream is threaded
    through the norms: calling an RMSNorm with ``(hidden, residual)`` yields
    ``(normed, new_residual)``. The MLP output is returned without adding the
    residual; ``XverseModel`` folds it into the next norm call.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Optional config fields, with fallbacks when the checkpoint omits them.
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        num_key_value_heads = getattr(
            config, "num_key_value_heads", config.num_attention_heads
        )
        self.self_attn = XverseAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        if config.num_experts is None:
            self.mlp = XverseMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
            )
        else:
            self.mlp = XverseMoE(config=config, quant_config=quant_config)
        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Attention sub-block. The first layer has no residual yet, so the raw
        # input becomes the residual.
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            input_metadata=input_metadata,
        )

        # Feed-forward sub-block.
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual
|
322
|
+
|
323
|
+
|
324
|
+
class XverseModel(nn.Module):
    """XVERSE decoder stack: token embedding, decoder layers, final RMSNorm.

    ``forward`` returns the final normalized hidden states. The LM head is
    applied by the enclosing causal-LM wrapper.
    """

    # NOTE(review): presumably tells the weight loader not to fall back to
    # PyTorch ``.pt`` checkpoint files. Confirm against the loader.
    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [
                XverseDecoderLayer(
                    config, layer_id, cache_config, quant_config=quant_config
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run all decoder layers plus the final norm.

        Each layer returns ``(hidden_states, residual)`` and leaves the
        residual add pending. The final ``self.norm(hidden_states, residual)``
        call consumes the last pending residual.
        """
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        # Iterate the ModuleList directly instead of indexing via range(len(...)).
        for layer in self.layers:
            hidden_states, residual = layer(
                positions, hidden_states, input_metadata, residual
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
|
367
|
+
|
368
|
+
|
369
|
+
class XverseMoeForCausalLM(nn.Module):
    """XVERSE-MoE causal language model: decoder stack plus parallel LM head.

    ``forward`` runs the decoder and passes the hidden states, the LM-head
    weight and the batch metadata to ``LogitsProcessor``.
    ``load_weights`` maps checkpoint tensors onto this module's parameters,
    merging separate q/k/v and gate/up shards into the fused projections.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = XverseModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(
            config.vocab_size, config.hidden_size, quant_config=quant_config
        )
        self.logits_processor = LogitsProcessor(config)

        # Snapshot of name -> parameter, taken once every submodule exists.
        self.param_dict = dict(self.named_parameters())

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        final_hidden = self.model(input_ids, positions, input_metadata)
        lm_weight = self.lm_head.weight
        return self.logits_processor(input_ids, final_hidden, lm_weight, input_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` pairs from a checkpoint into the model.

        Checkpoint names containing ``q_proj``/``k_proj``/``v_proj`` or
        ``gate_proj``/``up_proj`` are renamed to the fused ``qkv_proj`` /
        ``gate_up_proj`` parameter and loaded with the matching shard id.
        Other names go through the parameter's own ``weight_loader``, or
        ``default_weight_loader`` if it has none. ``rotary_emb.inv_freq``
        entries are ignored. Biases absent from the model (extra GPTQ biases)
        and expert weights absent from the model are skipped.

        NOTE(review): when a stacked shard is skipped, its rewritten name is
        carried into the remaining mapping entries. The for-else branch then
        re-applies the same skip checks, so the weight is still skipped.
        """
        stacked_params_mapping = (
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        )
        params_dict = self.param_dict

        def is_skippable(target: str) -> bool:
            # Skip loading extra bias for GPTQ models.
            if target.endswith(".bias") and target not in params_dict:
                return True
            # Skip experts that are not assigned to this worker.
            is_expert = "mlp.experts." in target or "mlp.shared_experts." in target
            return is_expert and target not in params_dict

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if is_skippable(name):
                    continue
                stacked_param = params_dict[name]
                stacked_param.weight_loader(stacked_param, loaded_weight, shard_id)
                break
            else:
                if is_skippable(name):
                    continue
                param = params_dict[name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, loaded_weight)
|
443
|
+
|
444
|
+
|
445
|
+
# Model class exported by this module.
# NOTE(review): presumably looked up by sglang's model registry by this exact
# name. Confirm against the model loader before renaming.
EntryClass = XverseMoeForCausalLM
|
sglang/srt/openai_api/adapter.py
CHANGED
@@ -22,12 +22,19 @@ import os
|
|
22
22
|
import time
|
23
23
|
import uuid
|
24
24
|
from http import HTTPStatus
|
25
|
-
from typing import Dict, List
|
25
|
+
from typing import Dict, List
|
26
26
|
|
27
27
|
from fastapi import HTTPException, Request, UploadFile
|
28
28
|
from fastapi.responses import JSONResponse, StreamingResponse
|
29
29
|
from pydantic import ValidationError
|
30
30
|
|
31
|
+
try:
|
32
|
+
from outlines.fsm.json_schema import convert_json_schema_to_str
|
33
|
+
except ImportError:
|
34
|
+
# Before outlines 0.0.47, convert_json_schema_to_str is under
|
35
|
+
# outlines.integrations.utils
|
36
|
+
from outlines.integrations.utils import convert_json_schema_to_str
|
37
|
+
|
31
38
|
from sglang.srt.conversation import (
|
32
39
|
Conversation,
|
33
40
|
SeparatorStyle,
|
@@ -88,19 +95,6 @@ file_id_storage: Dict[str, str] = {}
|
|
88
95
|
storage_dir = None
|
89
96
|
|
90
97
|
|
91
|
-
def format_finish_reason(finish_reason) -> Optional[str]:
|
92
|
-
if finish_reason.startswith("None"):
|
93
|
-
return None
|
94
|
-
elif finish_reason.startswith("FINISH_MATCHED"):
|
95
|
-
return "stop"
|
96
|
-
elif finish_reason.startswith("FINISH_LENGTH"):
|
97
|
-
return "length"
|
98
|
-
elif finish_reason.startswith("FINISH_ABORT"):
|
99
|
-
return "abort"
|
100
|
-
else:
|
101
|
-
return "unknown"
|
102
|
-
|
103
|
-
|
104
98
|
def create_error_response(
|
105
99
|
message: str,
|
106
100
|
err_type: str = "BadRequestError",
|
@@ -478,7 +472,7 @@ def v1_generate_request(
|
|
478
472
|
first_prompt_type = type(all_requests[0].prompt)
|
479
473
|
for request in all_requests:
|
480
474
|
assert (
|
481
|
-
type(request.prompt)
|
475
|
+
type(request.prompt) is first_prompt_type
|
482
476
|
), "All prompts must be of the same type in file input settings"
|
483
477
|
if len(all_requests) > 1 and request.n > 1:
|
484
478
|
raise ValueError(
|
@@ -611,8 +605,10 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
|
|
611
605
|
"index": 0,
|
612
606
|
"text": text,
|
613
607
|
"logprobs": logprobs,
|
614
|
-
"finish_reason":
|
615
|
-
ret_item["meta_info"]["finish_reason"]
|
608
|
+
"finish_reason": (
|
609
|
+
ret_item["meta_info"]["finish_reason"]["type"]
|
610
|
+
if ret_item["meta_info"]["finish_reason"]
|
611
|
+
else ""
|
616
612
|
),
|
617
613
|
}
|
618
614
|
else:
|
@@ -620,8 +616,10 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
|
|
620
616
|
index=idx,
|
621
617
|
text=text,
|
622
618
|
logprobs=logprobs,
|
623
|
-
finish_reason=
|
624
|
-
ret_item["meta_info"]["finish_reason"]
|
619
|
+
finish_reason=(
|
620
|
+
ret_item["meta_info"]["finish_reason"]["type"]
|
621
|
+
if ret_item["meta_info"]["finish_reason"]
|
622
|
+
else ""
|
625
623
|
),
|
626
624
|
)
|
627
625
|
|
@@ -755,8 +753,10 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
|
|
755
753
|
index=index,
|
756
754
|
text=delta,
|
757
755
|
logprobs=logprobs,
|
758
|
-
finish_reason=
|
759
|
-
content["meta_info"]["finish_reason"]
|
756
|
+
finish_reason=(
|
757
|
+
content["meta_info"]["finish_reason"]["type"]
|
758
|
+
if content["meta_info"]["finish_reason"]
|
759
|
+
else ""
|
760
760
|
),
|
761
761
|
)
|
762
762
|
chunk = CompletionStreamResponse(
|
@@ -832,6 +832,7 @@ def v1_chat_generate_request(
|
|
832
832
|
return_logprobs = []
|
833
833
|
logprob_start_lens = []
|
834
834
|
top_logprobs_nums = []
|
835
|
+
modalities_list = []
|
835
836
|
|
836
837
|
# NOTE: with openai API, the prompt's logprobs are always not computed
|
837
838
|
|
@@ -864,10 +865,12 @@ def v1_chat_generate_request(
|
|
864
865
|
)
|
865
866
|
stop = request.stop
|
866
867
|
image_data = None
|
868
|
+
modalities = []
|
867
869
|
else:
|
868
870
|
conv = generate_chat_conv(request, chat_template_name)
|
869
871
|
prompt = conv.get_prompt()
|
870
872
|
image_data = conv.image_data
|
873
|
+
modalities = conv.modalities
|
871
874
|
stop = conv.stop_str or []
|
872
875
|
if request.stop:
|
873
876
|
if isinstance(request.stop, str):
|
@@ -880,27 +883,33 @@ def v1_chat_generate_request(
|
|
880
883
|
prompt_ids = request.messages
|
881
884
|
stop = request.stop
|
882
885
|
image_data = None
|
886
|
+
modalities = []
|
883
887
|
input_ids.append(prompt_ids)
|
884
888
|
return_logprobs.append(request.logprobs)
|
885
889
|
logprob_start_lens.append(-1)
|
886
|
-
top_logprobs_nums.append(request.top_logprobs)
|
887
|
-
|
888
|
-
|
889
|
-
|
890
|
-
|
891
|
-
|
892
|
-
|
893
|
-
|
894
|
-
|
895
|
-
|
896
|
-
|
897
|
-
|
898
|
-
|
899
|
-
|
900
|
-
|
901
|
-
|
902
|
-
|
890
|
+
top_logprobs_nums.append(request.top_logprobs or 0)
|
891
|
+
|
892
|
+
sampling_params = {
|
893
|
+
"temperature": request.temperature,
|
894
|
+
"max_new_tokens": request.max_tokens,
|
895
|
+
"min_new_tokens": request.min_tokens,
|
896
|
+
"stop": stop,
|
897
|
+
"stop_token_ids": request.stop_token_ids,
|
898
|
+
"top_p": request.top_p,
|
899
|
+
"presence_penalty": request.presence_penalty,
|
900
|
+
"frequency_penalty": request.frequency_penalty,
|
901
|
+
"repetition_penalty": request.repetition_penalty,
|
902
|
+
"regex": request.regex,
|
903
|
+
"n": request.n,
|
904
|
+
}
|
905
|
+
if request.response_format and request.response_format.type == "json_schema":
|
906
|
+
sampling_params["json_schema"] = convert_json_schema_to_str(
|
907
|
+
request.response_format.json_schema.schema_
|
908
|
+
)
|
909
|
+
sampling_params_list.append(sampling_params)
|
910
|
+
|
903
911
|
image_data_list.append(image_data)
|
912
|
+
modalities_list.extend(modalities)
|
904
913
|
if len(all_requests) == 1:
|
905
914
|
input_ids = input_ids[0]
|
906
915
|
if isinstance(input_ids, str):
|
@@ -912,6 +921,7 @@ def v1_chat_generate_request(
|
|
912
921
|
return_logprobs = return_logprobs[0]
|
913
922
|
logprob_start_lens = logprob_start_lens[0]
|
914
923
|
top_logprobs_nums = top_logprobs_nums[0]
|
924
|
+
modalities_list = modalities_list[:1]
|
915
925
|
else:
|
916
926
|
if isinstance(input_ids[0], str):
|
917
927
|
prompt_kwargs = {"text": input_ids}
|
@@ -928,6 +938,7 @@ def v1_chat_generate_request(
|
|
928
938
|
stream=all_requests[0].stream,
|
929
939
|
return_text_in_logprobs=True,
|
930
940
|
rid=request_ids,
|
941
|
+
modalities=modalities_list,
|
931
942
|
)
|
932
943
|
if len(all_requests) == 1:
|
933
944
|
return adapted_request, all_requests[0]
|
@@ -981,8 +992,10 @@ def v1_chat_generate_response(request, ret, to_file=False):
|
|
981
992
|
"index": 0,
|
982
993
|
"message": {"role": "assistant", "content": ret_item["text"]},
|
983
994
|
"logprobs": choice_logprobs,
|
984
|
-
"finish_reason":
|
985
|
-
ret_item["meta_info"]["finish_reason"]
|
995
|
+
"finish_reason": (
|
996
|
+
ret_item["meta_info"]["finish_reason"]["type"]
|
997
|
+
if ret_item["meta_info"]["finish_reason"]
|
998
|
+
else ""
|
986
999
|
),
|
987
1000
|
}
|
988
1001
|
else:
|
@@ -990,8 +1003,10 @@ def v1_chat_generate_response(request, ret, to_file=False):
|
|
990
1003
|
index=idx,
|
991
1004
|
message=ChatMessage(role="assistant", content=ret_item["text"]),
|
992
1005
|
logprobs=choice_logprobs,
|
993
|
-
finish_reason=
|
994
|
-
ret_item["meta_info"]["finish_reason"]
|
1006
|
+
finish_reason=(
|
1007
|
+
ret_item["meta_info"]["finish_reason"]["type"]
|
1008
|
+
if ret_item["meta_info"]["finish_reason"]
|
1009
|
+
else ""
|
995
1010
|
),
|
996
1011
|
)
|
997
1012
|
|
@@ -1116,8 +1131,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
|
1116
1131
|
choice_data = ChatCompletionResponseStreamChoice(
|
1117
1132
|
index=index,
|
1118
1133
|
delta=DeltaMessage(role="assistant"),
|
1119
|
-
finish_reason=
|
1120
|
-
content["meta_info"]["finish_reason"]
|
1134
|
+
finish_reason=(
|
1135
|
+
content["meta_info"]["finish_reason"]["type"]
|
1136
|
+
if content["meta_info"]["finish_reason"]
|
1137
|
+
else ""
|
1121
1138
|
),
|
1122
1139
|
logprobs=choice_logprobs,
|
1123
1140
|
)
|
@@ -1134,8 +1151,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
|
1134
1151
|
choice_data = ChatCompletionResponseStreamChoice(
|
1135
1152
|
index=index,
|
1136
1153
|
delta=DeltaMessage(content=delta),
|
1137
|
-
finish_reason=
|
1138
|
-
content["meta_info"]["finish_reason"]
|
1154
|
+
finish_reason=(
|
1155
|
+
content["meta_info"]["finish_reason"]["type"]
|
1156
|
+
if content["meta_info"]["finish_reason"]
|
1157
|
+
else ""
|
1139
1158
|
),
|
1140
1159
|
logprobs=choice_logprobs,
|
1141
1160
|
)
|