sglang-0.1.17-py3-none-any.whl → sglang-0.1.18-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +4 -4
- sglang/backend/litellm.py +2 -2
- sglang/backend/openai.py +26 -15
- sglang/bench_latency.py +299 -0
- sglang/global_config.py +4 -1
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +1 -1
- sglang/lang/ir.py +15 -5
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +13 -6
- sglang/srt/constrained/fsm_cache.py +6 -3
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +2 -0
- sglang/srt/hf_transformers_utils.py +64 -9
- sglang/srt/layers/fused_moe.py +186 -89
- sglang/srt/layers/logits_processor.py +53 -25
- sglang/srt/layers/radix_attention.py +34 -7
- sglang/srt/managers/controller/dp_worker.py +6 -3
- sglang/srt/managers/controller/infer_batch.py +142 -67
- sglang/srt/managers/controller/manager_multi.py +5 -5
- sglang/srt/managers/controller/manager_single.py +8 -3
- sglang/srt/managers/controller/model_runner.py +154 -54
- sglang/srt/managers/controller/radix_cache.py +4 -0
- sglang/srt/managers/controller/schedule_heuristic.py +2 -0
- sglang/srt/managers/controller/tp_worker.py +140 -135
- sglang/srt/managers/detokenizer_manager.py +15 -19
- sglang/srt/managers/io_struct.py +10 -4
- sglang/srt/managers/tokenizer_manager.py +14 -13
- sglang/srt/model_config.py +83 -4
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/grok.py +204 -137
- sglang/srt/models/llama2.py +11 -4
- sglang/srt/models/llama_classification.py +104 -0
- sglang/srt/models/llava.py +11 -8
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/mixtral.py +164 -115
- sglang/srt/models/mixtral_quant.py +0 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/models/yivl.py +2 -2
- sglang/srt/openai_api_adapter.py +33 -23
- sglang/srt/openai_protocol.py +1 -1
- sglang/srt/server.py +60 -19
- sglang/srt/server_args.py +79 -44
- sglang/srt/utils.py +146 -37
- sglang/test/test_programs.py +28 -10
- sglang/utils.py +4 -3
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/METADATA +29 -22
- sglang-0.1.18.dist-info/RECORD +78 -0
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
- sglang/srt/managers/router/infer_batch.py +0 -596
- sglang/srt/managers/router/manager.py +0 -82
- sglang/srt/managers/router/model_rpc.py +0 -818
- sglang/srt/managers/router/model_runner.py +0 -445
- sglang/srt/managers/router/radix_cache.py +0 -267
- sglang/srt/managers/router/scheduler.py +0 -59
- sglang-0.1.17.dist-info/RECORD +0 -81
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
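For readers who want to reproduce a wheel-to-wheel comparison like the one below, here is a minimal sketch using only the Python standard library. The local wheel file names are assumptions; point the constants at any two downloaded wheels.

# Sketch: diff the Python sources inside two wheels with the standard library only.
# The wheel paths are assumptions; adjust them to wherever the files live locally.
import difflib
import zipfile

OLD_WHEEL = "sglang-0.1.17-py3-none-any.whl"
NEW_WHEEL = "sglang-0.1.18-py3-none-any.whl"

def read_sources(path):
    # A wheel is a zip archive; collect every .py file it contains.
    with zipfile.ZipFile(path) as zf:
        return {
            name: zf.read(name).decode("utf-8", errors="replace").splitlines()
            for name in zf.namelist()
            if name.endswith(".py")
        }

old, new = read_sources(OLD_WHEEL), read_sources(NEW_WHEEL)
for name in sorted(set(old) | set(new)):
    diff = difflib.unified_diff(
        old.get(name, []), new.get(name, []),
        fromfile=f"0.1.17/{name}", tofile=f"0.1.18/{name}", lineterm="",
    )
    for line in diff:
        print(line)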
sglang/srt/models/llava.py
CHANGED
@@ -1,11 +1,17 @@
 """Inference-only LLaVa model compatible with HuggingFace weights."""
 
-from typing import
+from typing import Iterable, List, Optional, Tuple
 
 import numpy as np
 import torch
 from torch import nn
-from transformers import
+from transformers import (
+    CLIPVisionConfig,
+    CLIPVisionModel,
+    LlavaConfig,
+    MistralConfig,
+    Qwen2Config,
+)
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
@@ -19,8 +25,8 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.models.llama2 import LlamaForCausalLM
-from sglang.srt.models.qwen2 import Qwen2ForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
 
 
 class LlavaLlamaForCausalLM(nn.Module):
@@ -359,6 +365,7 @@ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
 
 first_call = True
 
+
 def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
     batch_size = pixel_values.shape[0]
 
@@ -388,8 +395,4 @@ def monkey_path_clip_vision_embed_forward():
     )
 
 
-EntryClass = [
-    LlavaLlamaForCausalLM,
-    LlavaQwenForCausalLM,
-    LlavaMistralForCausalLM
-]
+EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
sglang/srt/models/llavavid.py
CHANGED
sglang/srt/models/mixtral.py
CHANGED
@@ -33,13 +33,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.utils import print_warning_once
 
-
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.controller.model_runner import InputMetadata
 
 
-
 class MixtralMoE(nn.Module):
     """A tensor-parallel MoE implementation for Mixtral that shards each expert
     across all ranks.
@@ -76,32 +74,46 @@ class MixtralMoE(nn.Module):
         self.params_dtype = params_dtype
 
         # Gate always runs at half / full precision for now.
-        self.gate = ReplicatedLinear(
+        self.gate = ReplicatedLinear(
+            self.hidden_size,
+            self.num_total_experts,
+            bias=False,
+            params_dtype=self.params_dtype,
+            quant_config=None,
+        )
 
         if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
 
         self.w13_weight = nn.Parameter(
-            torch.empty(
+            torch.empty(
+                self.num_total_experts,
+                2 * self.intermediate_size,
+                self.hidden_size,
+                dtype=params_dtype,
+            )
+        )
         self.w2_weight = nn.Parameter(
-            torch.empty(
-        set_weight_attrs(
+            torch.empty(
+                self.num_total_experts,
+                self.hidden_size,
+                self.intermediate_size,
+                dtype=params_dtype,
+            )
+        )
+
+        set_weight_attrs(
+            self.w13_weight,
+            {
+                "weight_loader": self.weight_loader,
+            },
+        )
+        set_weight_attrs(
+            self.w2_weight,
+            {
+                "weight_loader": self.weight_loader,
+            },
+        )
 
         # Used for fp8.
         self.w13_scale = None
@@ -111,46 +123,68 @@ class MixtralMoE(nn.Module):
 
         if self.use_fp8:
             # WEIGHT_SCALE (for fp8)
-            self.w13_scale = nn.Parameter(
+            self.w13_scale = nn.Parameter(
+                torch.ones(self.num_total_experts, dtype=torch.float32),
+                requires_grad=False,
+            )
+            self.w2_scale = nn.Parameter(
+                torch.ones(self.num_total_experts, dtype=torch.float32),
+                requires_grad=False,
+            )
 
             # If loading fp8 checkpoint, pass the weight loaders.
             # If loading an fp16 checkpoint, do not (we will quantize in
             # process_weights_after_loading()
             if quant_config.is_checkpoint_fp8_serialized:
-                set_weight_attrs(
+                set_weight_attrs(
+                    self.w13_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
+                set_weight_attrs(
+                    self.w2_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
 
             # ACT_SCALE (for fp8)
             if quant_config.activation_scheme == "static":
                 if not quant_config.is_checkpoint_fp8_serialized:
                     raise ValueError(
                         "Found static activation scheme for checkpoint that "
-                        "was not serialized fp8."
-                set_weight_attrs(
+                        "was not serialized fp8."
+                    )
+                self.a13_scale = nn.Parameter(
+                    torch.zeros(self.num_total_experts, dtype=torch.float32),
+                    requires_grad=False,
+                )
+                self.a2_scale = nn.Parameter(
+                    torch.zeros(self.num_total_experts, dtype=torch.float32),
+                    requires_grad=False,
+                )
+
+                set_weight_attrs(
+                    self.a13_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
+                set_weight_attrs(
+                    self.a2_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
+
+    def weight_loader(
+        self,
+        param: nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        expert_id: int,
+    ):
         tp_rank = get_tensor_model_parallel_rank()
         param_data = param.data
         shard_size = self.intermediate_size
@@ -158,8 +192,9 @@ class MixtralMoE(nn.Module):
         if weight_name.endswith("w1.weight"):
             param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
         if weight_name.endswith("w3.weight"):
-            param_data[expert_id,
+            param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
+                shard, :
+            ]
         if weight_name.endswith("w2.weight"):
             param_data[expert_id, :, :] = loaded_weight[:, shard]
         if "act_scale" in weight_name or "weight_scale" in weight_name:
@@ -172,17 +207,17 @@ class MixtralMoE(nn.Module):
 
         # If checkpoint is fp16, quantize here.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            w13_weight = torch.empty_like(
+            w13_weight = torch.empty_like(
+                self.w13_weight.data, dtype=torch.float8_e4m3fn
+            )
+            w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
             for expert in range(self.num_total_experts):
-                w13_weight[expert, :, :], self.w13_scale[
-                    expert
-                w2_weight[expert, :, :], self.w2_scale[
-                    expert
+                w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
+                    self.w13_weight.data[expert, :, :]
+                )
+                w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
+                    self.w2_weight.data[expert, :, :]
+                )
             self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
             self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
 
@@ -193,40 +228,40 @@ class MixtralMoE(nn.Module):
             if self.a13_scale is None or self.a2_scale is None:
                 raise ValueError(
                     "QuantConfig has static quantization, but found "
-                    "activation scales are None."
+                    "activation scales are None."
+                )
 
-            if
-                    or not all_close_1d(self.a2_scale)):
+            if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
                 print_warning_once(
                     "Found act_scales that are not equal for fp8 MoE layer. "
-                    "Using the maximum across experts for each layer. "
+                    "Using the maximum across experts for each layer. "
+                )
 
-            self.a13_scale = nn.Parameter(self.a13_scale.max(),
-            self.a2_scale = nn.Parameter(self.a2_scale.max(),
-                                         requires_grad=False)
+            self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
+            self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(
+        final_hidden_states = fused_moe(
+            hidden_states,
+            self.w13_weight,
+            self.w2_weight,
+            router_logits,
+            self.top_k,
+            renormalize=True,
+            inplace=True,
+            use_fp8=self.use_fp8,
+            w1_scale=self.w13_scale,
+            w2_scale=self.w2_scale,
+            a1_scale=self.a13_scale,
+            a2_scale=self.a2_scale,
+        )
 
         if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_size)
 
@@ -335,7 +370,8 @@ class MixtralDecoderLayer(nn.Module):
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
-            quant_config=quant_config
+            quant_config=quant_config,
+        )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
@@ -444,35 +480,48 @@ class MixtralForCausalLM(nn.Module):
             ("qkv_proj", "v_proj", "v"),
         ]
 
-        expert_params_mapping =
+        expert_params_mapping = (
+            [
+                # These are the weight scales for the experts
+                # (param_name, weight_name, expert_id)
+                (
+                    "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
+                    f"experts.{expert_id}.{weight_name}.weight_scale",
+                    expert_id,
+                )
+                for expert_id in range(self.config.num_local_experts)
+                for weight_name in ["w1", "w2", "w3"]
+            ]
+            + [
+                # These are the weights for the experts
+                # (param_name, weight_name, expert_id)
+                (
+                    "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
+                    f"experts.{expert_id}.{weight_name}.weight",
+                    expert_id,
+                )
+                for expert_id in range(self.config.num_local_experts)
+                for weight_name in ["w1", "w2", "w3"]
+            ]
+            + [
+                # These are the activation scales for the experts
+                # (param_name, weight_name, expert_id)
+                (
+                    "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
+                    f"experts.{expert_id}.{weight_name}.act_scale",
+                    expert_id,
+                )
+                for expert_id in range(self.config.num_local_experts)
+                for weight_name in ["w1", "w2", "w3"]
+            ]
+        )
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
 
-            for
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
@@ -490,18 +539,18 @@ class MixtralForCausalLM(nn.Module):
                 name = name.replace(weight_name, param_name)
                 param = params_dict[name]
                 weight_loader = param.weight_loader
-                weight_loader(
-                    expert_id=expert_id)
+                weight_loader(
+                    param, loaded_weight, weight_name, expert_id=expert_id
+                )
                 break
             else:
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
                 param = params_dict[name]
-                weight_loader = getattr(
+                weight_loader = getattr(
+                    param, "weight_loader", default_weight_loader
+                )
                 weight_loader(param, loaded_weight)
 
 
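The mixtral.py hunks above share a pattern worth noting: each expert weight is a plain nn.Parameter holding every expert's tensor in one stacked array, and set_weight_attrs attaches a weight_loader callback that the checkpoint loader invokes once per (expert, shard). A minimal self-contained sketch of that idea follows; ToyMoE and the stand-in set_weight_attrs helper are assumptions for illustration, not the sglang or vllm implementations.

# Illustrative sketch of the "attach a weight_loader to a Parameter" pattern
# seen in the mixtral.py diff above. ToyMoE and this set_weight_attrs stand-in
# are assumptions for the example, not sglang/vllm code.
import torch
from torch import nn


def set_weight_attrs(param: nn.Parameter, attrs: dict) -> None:
    # Stand-in for vllm's helper: copy extra attributes onto the parameter.
    for key, value in attrs.items():
        setattr(param, key, value)


class ToyMoE(nn.Module):
    def __init__(self, num_experts: int, hidden: int, intermediate: int):
        super().__init__()
        # One stacked tensor holds every expert's weights, as in the diff.
        self.w2_weight = nn.Parameter(torch.empty(num_experts, hidden, intermediate))
        set_weight_attrs(self.w2_weight, {"weight_loader": self.weight_loader})

    def weight_loader(
        self, param: nn.Parameter, loaded_weight: torch.Tensor, expert_id: int
    ):
        # The checkpoint loader calls this once per expert; copy into that slice.
        param.data[expert_id].copy_(loaded_weight)


moe = ToyMoE(num_experts=2, hidden=4, intermediate=3)
for expert_id in range(2):
    loader = getattr(moe.w2_weight, "weight_loader")
    loader(moe.w2_weight, torch.randn(4, 3), expert_id)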
sglang/srt/models/mixtral_quant.py
CHANGED
@@ -28,7 +28,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.controller.model_runner import InputMetadata
sglang/srt/models/qwen.py
CHANGED
@@ -1,6 +1,6 @@
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
-from typing import Any, Dict,
+from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
 from torch import nn
sglang/srt/models/qwen2.py
CHANGED
@@ -1,7 +1,7 @@
 # Adapted from llama2.py
 # Modify details for the adaptation of Qwen2 model.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
 from torch import nn
sglang/srt/models/stablelm.py
CHANGED
@@ -2,7 +2,7 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/stablelm.py#L1
 """Inference-only StableLM-2 (https://huggingface.co/stabilityai/stablelm-2-1_6b)
 model compatible with HuggingFace weights."""
-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple
 
 import torch
 from torch import nn
sglang/srt/models/yivl.py
CHANGED
@@ -1,14 +1,14 @@
 """Inference-only Yi-VL model."""
 
-from typing import
+from typing import Iterable, Optional, Tuple
 
 import torch
 import torch.nn as nn
 from transformers import CLIPVisionModel, LlavaConfig
 from vllm.config import CacheConfig
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.models.llava import (
     LlavaLlamaForCausalLM,
     monkey_path_clip_vision_embed_forward,
sglang/srt/openai_api_adapter.py
CHANGED
@@ -6,7 +6,7 @@ import os
 from http import HTTPStatus
 
 from fastapi import Request
-from fastapi.responses import
+from fastapi.responses import JSONResponse, StreamingResponse
 
 from sglang.srt.conversation import (
     Conversation,
@@ -40,21 +40,18 @@ chat_template_name = None
 def create_error_response(
     message: str,
     err_type: str = "BadRequestError",
-    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST
-    return JSONResponse(content=error.model_dump(),
-                        status_code=error.code)
+    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
+):
+    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
+    return JSONResponse(content=error.model_dump(), status_code=error.code)
 
 
 def create_streaming_error_response(
     message: str,
     err_type: str = "BadRequestError",
-    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST
-                          code=status_code.value)
+    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
+) -> str:
+    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
     json_str = json.dumps({"error": error.model_dump()})
     return json_str
 
@@ -125,7 +122,8 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
         n_prev_token = 0
         try:
             async for content in tokenizer_manager.generate_request(
-                adapted_request, raw_request
+                adapted_request, raw_request
+            ):
                 text = content["text"]
                 prompt_tokens = content["meta_info"]["prompt_tokens"]
                 completion_tokens = content["meta_info"]["completion_tokens"]
@@ -154,12 +152,14 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                         decode_token_logprobs=content["meta_info"][
                             "decode_token_logprobs"
                         ][n_prev_token:],
-                        decode_top_logprobs=content["meta_info"][
-                        ],
+                        decode_top_logprobs=content["meta_info"][
+                            "decode_top_logprobs"
+                        ][n_prev_token:],
                     )
 
-                    n_prev_token = len(
+                    n_prev_token = len(
+                        content["meta_info"]["decode_token_logprobs"]
+                    )
                 else:
                     logprobs = None
 
@@ -188,13 +188,17 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
             yield f"data: {error}\n\n"
             yield "data: [DONE]\n\n"
 
-        return StreamingResponse(
+        return StreamingResponse(
+            generate_stream_resp(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(adapted_request),
+        )
 
     # Non-streaming response.
     try:
         ret = await tokenizer_manager.generate_request(
-            adapted_request, raw_request
+            adapted_request, raw_request
+        ).__anext__()
     except ValueError as e:
         return create_error_response(str(e))
 
@@ -299,7 +303,9 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
 
         stream_buffer = ""
         try:
-            async for content in tokenizer_manager.generate_request(
+            async for content in tokenizer_manager.generate_request(
+                adapted_request, raw_request
+            ):
                 if is_first:
                     # First chunk with role
                     is_first = False
@@ -334,13 +340,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
             yield f"data: {error}\n\n"
             yield "data: [DONE]\n\n"
 
-        return StreamingResponse(
+        return StreamingResponse(
+            generate_stream_resp(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(adapted_request),
+        )
 
     # Non-streaming response.
     try:
         ret = await tokenizer_manager.generate_request(
-            adapted_request, raw_request
+            adapted_request, raw_request
+        ).__anext__()
     except ValueError as e:
         return create_error_response(str(e))
 
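The openai_api_adapter.py hunks above follow the common FastAPI streaming shape: an async generator yields Server-Sent Events ("data: ..." lines, errors included, ending with "data: [DONE]"), and the StreamingResponse carries a background task so the in-flight generation can be cleaned up once the response finishes. A minimal sketch of that shape, under assumptions, is below; the app, fake_stream(), and cancel_job() names are illustrative, not sglang's actual adapter code.

# Sketch of the SSE + background-task shape used by the adapter diff above.
# fake_stream() and cancel_job() are assumed placeholder names for illustration.
import asyncio
import json

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from starlette.background import BackgroundTask

app = FastAPI()


async def fake_stream():
    # Yield Server-Sent Events; a real adapter would forward model output here.
    for i in range(3):
        yield f"data: {json.dumps({'chunk': i})}\n\n"
        await asyncio.sleep(0.1)
    yield "data: [DONE]\n\n"


def cancel_job():
    # Placeholder cleanup that runs after the response finishes,
    # e.g. when the client disconnects mid-stream.
    print("stream finished, cleaning up generation job")


@app.get("/v1/stream-demo")
async def stream_demo():
    return StreamingResponse(
        fake_stream(),
        media_type="text/event-stream",
        background=BackgroundTask(cancel_job),
    )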
sglang/srt/openai_protocol.py
CHANGED