sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +59 -2
- sglang/api.py +40 -11
- sglang/backend/anthropic.py +17 -3
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +160 -12
- sglang/backend/runtime_endpoint.py +62 -27
- sglang/backend/vertexai.py +1 -0
- sglang/bench_latency.py +320 -0
- sglang/global_config.py +24 -3
- sglang/lang/chat_template.py +122 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +206 -98
- sglang/lang/ir.py +98 -34
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +32 -0
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +9 -2
- sglang/srt/constrained/jump_forward.py +113 -24
- sglang/srt/conversation.py +4 -2
- sglang/srt/flush_cache.py +18 -0
- sglang/srt/hf_transformers_utils.py +144 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +20 -1
- sglang/srt/layers/fused_moe.py +596 -0
- sglang/srt/layers/logits_processor.py +190 -61
- sglang/srt/layers/radix_attention.py +62 -53
- sglang/srt/layers/token_attention.py +21 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/controller/infer_batch.py +908 -0
- sglang/srt/managers/controller/manager_multi.py +195 -0
- sglang/srt/managers/controller/manager_single.py +177 -0
- sglang/srt/managers/controller/model_runner.py +359 -0
- sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
- sglang/srt/managers/controller/schedule_heuristic.py +65 -0
- sglang/srt/managers/controller/tp_worker.py +813 -0
- sglang/srt/managers/detokenizer_manager.py +42 -40
- sglang/srt/managers/io_struct.py +44 -10
- sglang/srt/managers/tokenizer_manager.py +224 -82
- sglang/srt/memory_pool.py +52 -59
- sglang/srt/model_config.py +97 -2
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +369 -0
- sglang/srt/models/dbrx.py +406 -0
- sglang/srt/models/gemma.py +34 -38
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +47 -37
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +92 -27
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/minicpm.py +366 -0
- sglang/srt/models/mixtral.py +302 -127
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +40 -35
- sglang/srt/models/qwen2.py +33 -36
- sglang/srt/models/qwen2_moe.py +473 -0
- sglang/srt/models/stablelm.py +33 -39
- sglang/srt/models/yivl.py +19 -26
- sglang/srt/openai_api_adapter.py +411 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +197 -481
- sglang/srt/server_args.py +190 -74
- sglang/srt/utils.py +460 -95
- sglang/test/test_programs.py +73 -10
- sglang/test/test_utils.py +226 -7
- sglang/utils.py +97 -27
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
- sglang-0.1.21.dist-info/RECORD +82 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/managers/router/infer_batch.py +0 -503
- sglang/srt/managers/router/manager.py +0 -79
- sglang/srt/managers/router/model_rpc.py +0 -686
- sglang/srt/managers/router/model_runner.py +0 -514
- sglang/srt/managers/router/scheduler.py +0 -70
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/models/stablelm.py
CHANGED
@@ -1,40 +1,38 @@
-#
-# https://github.com/vllm-project/vllm/blob/
+# Adapted from:
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/stablelm.py#L1
 """Inference-only StableLM-2 (https://huggingface.co/stabilityai/stablelm-2-1_6b)
 model compatible with HuggingFace weights."""
-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple

 import torch
 from torch import nn
 from transformers import PretrainedConfig
-
-from
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from vllm.config import CacheConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (
-    LinearMethodBase,
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding,
     ParallelLMHead,
+    VocabParallelEmbedding,
 )
-from vllm.model_executor.
-
-
-from
-
-    hf_model_weights_iterator,
-)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.model_runner import InputMetadata


 class StablelmMLP(nn.Module):
     def __init__(
-        self,
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -44,10 +42,13 @@ class StablelmMLP(nn.Module):
             config.hidden_size,
             [config.intermediate_size] * 2,
             bias=False,
-
+            quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            config.intermediate_size,
+            config.intermediate_size,
+            config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
         )
         self.act_fn = SiluAndMul()

@@ -63,7 +64,7 @@ class StablelmAttention(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int = 0,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -105,13 +106,11 @@ class StablelmAttention(nn.Module):
             self.total_num_heads,
             self.total_num_key_value_heads,
             self.qkv_bias,
-            linear_method=linear_method,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=False,
-            linear_method=linear_method,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -146,11 +145,11 @@ class StablelmDecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int = 0,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.self_attn = StablelmAttention(config, layer_id=layer_id)
-        self.mlp = StablelmMLP(config,
+        self.mlp = StablelmMLP(config, quant_config=quant_config)
         norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05))
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
@@ -182,7 +181,9 @@ class StablelmDecoderLayer(nn.Module):

 class StableLMEpochModel(nn.Module):
     def __init__(
-        self,
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.embed_tokens = VocabParallelEmbedding(
@@ -191,7 +192,7 @@ class StableLMEpochModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                StablelmDecoderLayer(config, i,
+                StablelmDecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -224,12 +225,13 @@ class StableLmForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.
-        self.model = StableLMEpochModel(config,
+        self.quant_config = quant_config
+        self.model = StableLMEpochModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

@@ -245,13 +247,7 @@ class StableLmForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )

-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -261,9 +257,7 @@ class StableLmForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
sglang/srt/models/yivl.py
CHANGED
@@ -1,42 +1,38 @@
 """Inference-only Yi-VL model."""

-import
-from typing import List, Optional
+from typing import Iterable, Optional, Tuple

 import torch
 import torch.nn as nn
+from transformers import CLIPVisionModel, LlavaConfig
+from vllm.config import CacheConfig
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
 from sglang.srt.models.llava import (
     LlavaLlamaForCausalLM,
-    clip_vision_embed_forward,
     monkey_path_clip_vision_embed_forward,
 )
-from transformers import CLIPVisionModel, LlavaConfig
-from vllm.model_executor.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)


 class YiVLForCausalLM(LlavaLlamaForCausalLM):
-    def __init__(
-        self
-
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__(config, quant_config, cache_config)

         self.multi_modal_projector = YiVLMultiModalProjector(self.config)
         self.vision_tower_subfolder = self.config.mm_vision_tower.replace(
             "./", ""
         ) # Everything after "./"

-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # We have to use the subfolder of the main model directory (e.g. 01-ai/Yi-VL-6B)
         self.vision_tower = CLIPVisionModel.from_pretrained(
-
+            self.config._name_or_path,
             torch_dtype=torch.float16,
             subfolder=self.vision_tower_subfolder,
         ).cuda()
@@ -70,9 +66,8 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
             "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
         }
         params_dict = dict(self.named_parameters())
-
-
-        ):
+        weights = list(weights)
+        for name, loaded_weight in weights:
             if "projector" in name or "vision_tower" in name:
                 for weight_name, param_name in projector_weights.items():
                     if weight_name in name:
@@ -82,9 +77,7 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
                         weight_loader(param, loaded_weight)

         # load language model
-        self.language_model.load_weights(
-            model_name_or_path, cache_dir, load_format, revision
-        )
+        self.language_model.load_weights(weights)

         monkey_path_clip_vision_embed_forward()

@@ -105,7 +98,7 @@ class YiVLMultiModalProjector(nn.Module):

     def forward(self, image_features):
         hidden_states = self.linear_1(image_features)
-
+        hidden_states = self.ln_1(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states = self.linear_2(hidden_states)
         hidden_states = self.ln_2(hidden_states)
sglang/srt/openai_api_adapter.py
ADDED
@@ -0,0 +1,411 @@
+"""Conversion between OpenAI APIs and native SRT APIs"""
+
+import asyncio
+import json
+import os
+from http import HTTPStatus
+
+from fastapi import Request
+from fastapi.responses import JSONResponse, StreamingResponse
+
+from sglang.srt.conversation import (
+    Conversation,
+    SeparatorStyle,
+    chat_template_exists,
+    generate_chat_conv,
+    register_conv_template,
+)
+from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.openai_protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatCompletionResponseStreamChoice,
+    ChatCompletionStreamResponse,
+    ChatMessage,
+    CompletionRequest,
+    CompletionResponse,
+    CompletionResponseChoice,
+    CompletionResponseStreamChoice,
+    CompletionStreamResponse,
+    DeltaMessage,
+    ErrorResponse,
+    LogProbs,
+    UsageInfo,
+)
+
+chat_template_name = None
+
+
+def create_error_response(
+    message: str,
+    err_type: str = "BadRequestError",
+    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
+):
+    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
+    return JSONResponse(content=error.model_dump(), status_code=error.code)
+
+
+def create_streaming_error_response(
+    message: str,
+    err_type: str = "BadRequestError",
+    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
+) -> str:
+    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
+    json_str = json.dumps({"error": error.model_dump()})
+    return json_str
+
+
+def load_chat_template_for_openai_api(chat_template_arg):
+    global chat_template_name
+
+    print(f"Use chat template: {chat_template_arg}")
+    if not chat_template_exists(chat_template_arg):
+        if not os.path.exists(chat_template_arg):
+            raise RuntimeError(
+                f"Chat template {chat_template_arg} is not a built-in template name "
+                "or a valid chat template file path."
+            )
+        with open(chat_template_arg, "r") as filep:
+            template = json.load(filep)
+            try:
+                sep_style = SeparatorStyle[template["sep_style"]]
+            except KeyError:
+                raise ValueError(
+                    f"Unknown separator style: {template['sep_style']}"
+                ) from None
+            register_conv_template(
+                Conversation(
+                    name=template["name"],
+                    system_template=template["system"] + "\n{system_message}",
+                    system_message=template.get("system_message", ""),
+                    roles=(template["user"], template["assistant"]),
+                    sep_style=sep_style,
+                    sep=template.get("sep", "\n"),
+                    stop_str=template["stop_str"],
+                ),
+                override=True,
+            )
+            chat_template_name = template["name"]
+    else:
+        chat_template_name = chat_template_arg
+
+
+async def v1_completions(tokenizer_manager, raw_request: Request):
+    request_json = await raw_request.json()
+    request = CompletionRequest(**request_json)
+
+    if request.n != 1:
+        return create_error_response("n != 1 is not supported")
+
+    adapted_request = GenerateReqInput(
+        text=request.prompt,
+        sampling_params={
+            "temperature": request.temperature,
+            "max_new_tokens": request.max_tokens,
+            "stop": request.stop,
+            "top_p": request.top_p,
+            "presence_penalty": request.presence_penalty,
+            "frequency_penalty": request.frequency_penalty,
+            "regex": request.regex,
+        },
+        return_logprob=request.logprobs is not None and request.logprobs > 0,
+        top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
+        return_text_in_logprobs=True,
+        stream=request.stream,
+    )
+
+    if adapted_request.stream:
+
+        async def generate_stream_resp():
+            stream_buffer = ""
+            n_prev_token = 0
+            try:
+                async for content in tokenizer_manager.generate_request(
+                    adapted_request, raw_request
+                ):
+                    text = content["text"]
+                    prompt_tokens = content["meta_info"]["prompt_tokens"]
+                    completion_tokens = content["meta_info"]["completion_tokens"]
+
+                    if not stream_buffer: # The first chunk
+                        if request.echo:
+                            # Prepend prompt in response text.
+                            text = request.prompt + text
+
+                    if request.logprobs:
+                        # The first chunk and echo is enabled.
+                        if not stream_buffer and request.echo:
+                            prefill_token_logprobs = content["meta_info"][
+                                "prefill_token_logprobs"
+                            ]
+                            prefill_top_logprobs = content["meta_info"][
+                                "prefill_top_logprobs"
+                            ]
+                        else:
+                            prefill_token_logprobs = None
+                            prefill_top_logprobs = None
+
+                        logprobs = to_openai_style_logprobs(
+                            prefill_token_logprobs=prefill_token_logprobs,
+                            prefill_top_logprobs=prefill_top_logprobs,
+                            decode_token_logprobs=content["meta_info"][
+                                "decode_token_logprobs"
+                            ][n_prev_token:],
+                            decode_top_logprobs=content["meta_info"][
+                                "decode_top_logprobs"
+                            ][n_prev_token:],
+                        )
+
+                        n_prev_token = len(
+                            content["meta_info"]["decode_token_logprobs"]
+                        )
+                    else:
+                        logprobs = None
+
+                    delta = text[len(stream_buffer) :]
+                    stream_buffer = stream_buffer + delta
+                    choice_data = CompletionResponseStreamChoice(
+                        index=0,
+                        text=delta,
+                        logprobs=logprobs,
+                        finish_reason=content["meta_info"]["finish_reason"],
+                    )
+                    chunk = CompletionStreamResponse(
+                        id=content["meta_info"]["id"],
+                        object="text_completion",
+                        choices=[choice_data],
+                        model=request.model,
+                        usage=UsageInfo(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=prompt_tokens + completion_tokens,
+                        ),
+                    )
+                    yield f"data: {chunk.model_dump_json()}\n\n"
+            except ValueError as e:
+                error = create_streaming_error_response(str(e))
+                yield f"data: {error}\n\n"
+            yield "data: [DONE]\n\n"
+
+        return StreamingResponse(
+            generate_stream_resp(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(adapted_request),
+        )
+
+    # Non-streaming response.
+    try:
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request
+        ).__anext__()
+    except ValueError as e:
+        return create_error_response(str(e))
+
+    ret = ret[0] if isinstance(ret, list) else ret
+    prompt_tokens = ret["meta_info"]["prompt_tokens"]
+    completion_tokens = ret["meta_info"]["completion_tokens"]
+    text = ret["text"]
+    if request.echo:
+        text = request.prompt + text
+
+    if request.logprobs:
+        if request.echo:
+            prefill_token_logprobs = ret["meta_info"]["prefill_token_logprobs"]
+            prefill_top_logprobs = ret["meta_info"]["prefill_top_logprobs"]
+        else:
+            prefill_token_logprobs = None
+            prefill_top_logprobs = None
+
+        logprobs = to_openai_style_logprobs(
+            prefill_token_logprobs=prefill_token_logprobs,
+            prefill_top_logprobs=prefill_top_logprobs,
+            decode_token_logprobs=ret["meta_info"]["decode_token_logprobs"],
+            decode_top_logprobs=ret["meta_info"]["decode_top_logprobs"],
+        )
+    else:
+        logprobs = None
+
+    choice_data = CompletionResponseChoice(
+        index=0,
+        text=text,
+        logprobs=logprobs,
+        finish_reason=ret["meta_info"]["finish_reason"],
+    )
+    response = CompletionResponse(
+        id=ret["meta_info"]["id"],
+        model=request.model,
+        choices=[choice_data],
+        usage=UsageInfo(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        ),
+    )
+    return response
+
+
+async def v1_chat_completions(tokenizer_manager, raw_request: Request):
+    request_json = await raw_request.json()
+    request = ChatCompletionRequest(**request_json)
+
+    if request.n != 1:
+        return create_error_response("n != 1 is not supported")
+
+    # Prep the data needed for the underlying GenerateReqInput:
+    #  - prompt: The full prompt string.
+    #  - stop: Custom stop tokens.
+    #  - image_data: None or a list of image strings (URLs or base64 strings).
+    #    None skips any image processing in GenerateReqInput.
+    if not isinstance(request.messages, str):
+        # Apply chat template and its stop strings.
+        if chat_template_name is None:
+            prompt = tokenizer_manager.tokenizer.apply_chat_template(
+                request.messages, tokenize=False, add_generation_prompt=True
+            )
+            stop = request.stop
+            image_data = None
+        else:
+            conv = generate_chat_conv(request, chat_template_name)
+            prompt = conv.get_prompt()
+            image_data = conv.image_data
+            stop = conv.stop_str or []
+            if request.stop:
+                if isinstance(request.stop, str):
+                    stop.append(request.stop)
+                else:
+                    stop.extend(request.stop)
+    else:
+        # Use the raw prompt and stop strings if the messages is already a string.
+        prompt = request.messages
+        stop = request.stop
+        image_data = None
+
+    adapted_request = GenerateReqInput(
+        text=prompt,
+        image_data=image_data,
+        sampling_params={
+            "temperature": request.temperature,
+            "max_new_tokens": request.max_tokens,
+            "stop": stop,
+            "top_p": request.top_p,
+            "presence_penalty": request.presence_penalty,
+            "frequency_penalty": request.frequency_penalty,
+            "regex": request.regex,
+        },
+        stream=request.stream,
+    )
+
+    if adapted_request.stream:
+
+        async def generate_stream_resp():
+            is_first = True
+
+            stream_buffer = ""
+            try:
+                async for content in tokenizer_manager.generate_request(
+                    adapted_request, raw_request
+                ):
+                    if is_first:
+                        # First chunk with role
+                        is_first = False
+                        choice_data = ChatCompletionResponseStreamChoice(
+                            index=0,
+                            delta=DeltaMessage(role="assistant"),
+                            finish_reason=content["meta_info"]["finish_reason"],
+                        )
+                        chunk = ChatCompletionStreamResponse(
+                            id=content["meta_info"]["id"],
+                            choices=[choice_data],
+                            model=request.model,
+                        )
+                        yield f"data: {chunk.model_dump_json()}\n\n"
+
+                    text = content["text"]
+                    delta = text[len(stream_buffer) :]
+                    stream_buffer = stream_buffer + delta
+                    choice_data = ChatCompletionResponseStreamChoice(
+                        index=0,
+                        delta=DeltaMessage(content=delta),
+                        finish_reason=content["meta_info"]["finish_reason"],
+                    )
+                    chunk = ChatCompletionStreamResponse(
+                        id=content["meta_info"]["id"],
+                        choices=[choice_data],
+                        model=request.model,
+                    )
+                    yield f"data: {chunk.model_dump_json()}\n\n"
+            except ValueError as e:
+                error = create_streaming_error_response(str(e))
+                yield f"data: {error}\n\n"
+            yield "data: [DONE]\n\n"
+
+        return StreamingResponse(
+            generate_stream_resp(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(adapted_request),
+        )
+
+    # Non-streaming response.
+    try:
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request
+        ).__anext__()
+    except ValueError as e:
+        return create_error_response(str(e))
+
+    prompt_tokens = ret["meta_info"]["prompt_tokens"]
+    completion_tokens = ret["meta_info"]["completion_tokens"]
+    choice_data = ChatCompletionResponseChoice(
+        index=0,
+        message=ChatMessage(role="assistant", content=ret["text"]),
+        finish_reason=ret["meta_info"]["finish_reason"],
+    )
+    response = ChatCompletionResponse(
+        id=ret["meta_info"]["id"],
+        model=request.model,
+        choices=[choice_data],
+        usage=UsageInfo(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        ),
+    )
+    return response
+
+
+def to_openai_style_logprobs(
+    prefill_token_logprobs=None,
+    decode_token_logprobs=None,
+    prefill_top_logprobs=None,
+    decode_top_logprobs=None,
+):
+    ret_logprobs = LogProbs()
+
+    def append_token_logprobs(token_logprobs):
+        for logprob, _, token_text in token_logprobs:
+            ret_logprobs.tokens.append(token_text)
+            ret_logprobs.token_logprobs.append(logprob)
+
+            # Not supported yet
+            ret_logprobs.text_offset.append(-1)
+
+    def append_top_logprobs(top_logprobs):
+        for tokens in top_logprobs:
+            if tokens is not None:
+                ret_logprobs.top_logprobs.append(
+                    {token[2]: token[0] for token in tokens}
+                )
+            else:
+                ret_logprobs.top_logprobs.append(None)
+
+    if prefill_token_logprobs is not None:
+        append_token_logprobs(prefill_token_logprobs)
+    if decode_token_logprobs is not None:
+        append_token_logprobs(decode_token_logprobs)
+    if prefill_top_logprobs is not None:
+        append_top_logprobs(prefill_top_logprobs)
+    if decode_top_logprobs is not None:
+        append_top_logprobs(decode_top_logprobs)
+
+    return ret_logprobs