sglang 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sglang/__init__.py +8 -0
- sglang/api.py +10 -2
- sglang/bench_latency.py +151 -40
- sglang/bench_serving.py +46 -22
- sglang/check_env.py +24 -2
- sglang/global_config.py +0 -1
- sglang/lang/backend/base_backend.py +3 -1
- sglang/lang/backend/openai.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +46 -29
- sglang/lang/choices.py +164 -0
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +6 -13
- sglang/lang/ir.py +14 -5
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/layers/activation.py +33 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +6 -1
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +6 -1
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +4 -7
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +174 -380
- sglang/srt/managers/tokenizer_manager.py +197 -112
- sglang/srt/managers/tp_worker.py +299 -364
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +10 -15
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +27 -12
- sglang/srt/model_executor/forward_batch_info.py +319 -0
- sglang/srt/model_executor/model_runner.py +30 -47
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +1 -1
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -2
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/internlm2.py +3 -8
- sglang/srt/models/llama2.py +5 -5
- sglang/srt/models/llama_classification.py +1 -1
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/llava.py +1 -2
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -1
- sglang/srt/models/mixtral.py +1 -1
- sglang/srt/models/mixtral_quant.py +1 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -12
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +189 -39
- sglang/srt/openai_api/protocol.py +43 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -4
- sglang/srt/server.py +93 -21
- sglang/srt/server_args.py +30 -19
- sglang/srt/utils.py +31 -13
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +63 -63
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +4 -2
- sglang/test/test_utils.py +21 -3
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/METADATA +50 -31
- sglang-0.2.12.dist-info/RECORD +112 -0
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang-0.2.10.dist-info/RECORD +0 -100
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama_embedding.py
ADDED
@@ -0,0 +1,88 @@
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.models.llama2 import LlamaForCausalLM, LlamaModel
+
+
+class LlamaEmbeddingModel(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config=None,
+        cache_config=None,
+        efficient_weight_load=False,
+    ) -> None:
+        super().__init__()
+        self.model = LlamaModel(config, quant_config=quant_config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> EmbeddingPoolerOutput:
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        return self.pooler(hidden_states, input_metadata)
+
+    def load_weights(
+        self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
+    ):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.model.named_parameters())
+
+        def load_weights_per_param(name, loaded_weight):
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                return
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                return
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name.startswith("model.vision_tower") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    return
+                if name.startswith("model.vision_tower") and name not in params_dict:
+                    return
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+        if name is None or loaded_weight is None:
+            for name, loaded_weight in weights:
+                load_weights_per_param(name, loaded_weight)
+        else:
+            load_weights_per_param(name, loaded_weight)
+
+
+EntryClass = LlamaEmbeddingModel
+# compat: e5-mistral model.config class == MistralModel
+EntryClassRemapping = [("MistralModel", LlamaEmbeddingModel)]
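The new model wires LlamaModel's hidden states into a last-token pooler with L2 normalization. As a rough illustration of what that pooling step computes (a standalone sketch, not sglang's Pooler API; the packed tensor layout is an assumption):

import torch

def last_token_pool(hidden_states: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    # hidden_states: [sum(seq_lens), hidden_size], all sequences packed back to back
    # seq_lens: [batch_size] number of tokens per sequence
    last_token_idx = torch.cumsum(seq_lens, dim=0) - 1          # index of each sequence's last token
    pooled = hidden_states[last_token_idx]                      # [batch_size, hidden_size]
    return torch.nn.functional.normalize(pooled, p=2, dim=-1)   # mirrors normalize=True in the Pooler above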
sglang/srt/models/llava.py
CHANGED
@@ -32,13 +32,12 @@ from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

-from sglang.srt.managers.schedule_batch import ForwardMode
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
     unpad_image_shape,
 )
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.models.llama2 import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
sglang/srt/models/llavavid.py
CHANGED
@@ -26,13 +26,12 @@ from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

-from sglang.srt.managers.schedule_batch import ForwardMode
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
     unpad_image_shape,
 )
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.models.llama2 import LlamaForCausalLM

sglang/srt/models/minicpm.py
CHANGED
@@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata


 class MiniCPMMLP(nn.Module):
sglang/srt/models/mixtral.py
CHANGED
@@ -50,7 +50,7 @@ from vllm.utils import print_warning_once

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata


 class MixtralMoE(nn.Module):
sglang/srt/models/mixtral_quant.py
CHANGED
@@ -45,7 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata


 class MixtralMLP(nn.Module):
sglang/srt/models/qwen.py
CHANGED
@@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata


 class QWenMLP(nn.Module):
sglang/srt/models/qwen2.py
CHANGED
@@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata

 Qwen2Config = None

sglang/srt/models/qwen2_moe.py
CHANGED
@@ -46,12 +46,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors, SamplerOutput

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata


 class Qwen2MoeMLP(nn.Module):
@@ -368,7 +366,6 @@ class Qwen2MoeForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -394,14 +391,6 @@ class Qwen2MoeForCausalLM(nn.Module):
         )
         return logits

-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
sglang/srt/models/stablelm.py
CHANGED
@@ -40,7 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata


 class StablelmMLP(nn.Module):
sglang/srt/openai_api/adapter.py
CHANGED
@@ -34,7 +34,7 @@ from sglang.srt.conversation import (
     generate_chat_conv,
     register_conv_template,
 )
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
 from sglang.srt.openai_api.protocol import (
     BatchRequest,
     BatchResponse,
@@ -52,7 +52,11 @@ from sglang.srt.openai_api.protocol import (
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
     DeltaMessage,
+    EmbeddingObject,
+    EmbeddingRequest,
+    EmbeddingResponse,
     ErrorResponse,
+    FileDeleteResponse,
     FileRequest,
     FileResponse,
     LogProbs,
@@ -73,7 +77,7 @@ class FileMetadata:
 batch_storage: Dict[str, BatchResponse] = {}
 file_id_request: Dict[str, FileMetadata] = {}
 file_id_response: Dict[str, FileResponse] = {}
-# map file id to file path in
+# map file id to file path in SGLang backend
 file_id_storage: Dict[str, str] = {}


@@ -81,6 +85,19 @@ file_id_storage: Dict[str, str] = {}
 storage_dir = None


+def format_finish_reason(finish_reason) -> Optional[str]:
+    if finish_reason.startswith("None"):
+        return None
+    elif finish_reason.startswith("FINISH_MATCHED"):
+        return "stop"
+    elif finish_reason.startswith("FINISH_LENGTH"):
+        return "length"
+    elif finish_reason.startswith("FINISH_ABORT"):
+        return "abort"
+    else:
+        return "unknown"
+
+
 def create_error_response(
     message: str,
     err_type: str = "BadRequestError",
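The new helper maps the backend's internal finish-reason strings onto the values OpenAI clients expect; anything unrecognized becomes "unknown". An illustrative check (the argument strings are assumed examples; only their FINISH_* prefixes matter to the helper):

from sglang.srt.openai_api.adapter import format_finish_reason

assert format_finish_reason("FINISH_MATCHED_STR") == "stop"
assert format_finish_reason("FINISH_LENGTH (max_new_tokens reached)") == "length"
assert format_finish_reason("None") is None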
@@ -174,6 +191,20 @@ async def v1_files_create(file: UploadFile, purpose: str, file_storage_pth: str
         return {"error": "Invalid input", "details": e.errors()}


+async def v1_delete_file(file_id: str):
+    # Retrieve the file job from the in-memory storage
+    file_response = file_id_response.get(file_id)
+    if file_response is None:
+        raise HTTPException(status_code=404, detail="File not found")
+    file_path = file_id_storage.get(file_id)
+    if file_path is None:
+        raise HTTPException(status_code=404, detail="File not found")
+    os.remove(file_path)
+    del file_id_response[file_id]
+    del file_id_storage[file_id]
+    return FileDeleteResponse(id=file_id, deleted=True)
+
+
 async def v1_batches(tokenizer_manager, raw_request: Request):
     try:
         body = await raw_request.json()
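A possible client-side call for the new delete handler, assuming the server wires it to DELETE /v1/files/{file_id} (the route registration lives in server.py, which also changed in this release; the URL and file id below are placeholders):

import requests

# Hypothetical endpoint and file id; adjust if the server exposes a different route.
resp = requests.delete("http://localhost:30000/v1/files/file-abc123")
print(resp.status_code, resp.json())  # on success, expect {"id": "file-abc123", "deleted": true, ...}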
@@ -287,6 +318,13 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         retrieve_batch = batch_storage[batch_id]
         retrieve_batch.output_file_id = output_file_id
         file_id_storage[output_file_id] = output_file_path
+        file_id_response[output_file_id] = FileResponse(
+            id=output_file_id,
+            bytes=os.path.getsize(output_file_path),
+            created_at=int(time.time()),
+            filename=f"{output_file_id}.jsonl",
+            purpose="batch_result",
+        )
         # Update batch status to "completed"
         retrieve_batch.status = "completed"
         retrieve_batch.completed_at = int(time.time())
@@ -297,7 +335,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         }

     except Exception as e:
-        print("error in
+        print("error in SGLang:", e)
         # Update batch status to "failed"
         retrieve_batch = batch_storage[batch_id]
         retrieve_batch.status = "failed"
@@ -335,7 +373,6 @@ async def v1_retrieve_file_content(file_id: str):


 def v1_generate_request(all_requests):
-
     prompts = []
     sampling_params_list = []
     return_logprobs = []
@@ -356,10 +393,13 @@ def v1_generate_request(all_requests):
             {
                 "temperature": request.temperature,
                 "max_new_tokens": request.max_tokens,
+                "min_new_tokens": request.min_tokens,
                 "stop": request.stop,
+                "stop_token_ids": request.stop_token_ids,
                 "top_p": request.top_p,
                 "presence_penalty": request.presence_penalty,
                 "frequency_penalty": request.frequency_penalty,
+                "repetition_penalty": request.repetition_penalty,
                 "regex": request.regex,
                 "n": request.n,
                 "ignore_eos": request.ignore_eos,
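The extra fields are read straight off the OpenAI-style request body, so clients can pass the new sampling controls directly. A hedged example against a locally running server (URL, model name, and stop token id are placeholders):

import requests

payload = {
    "model": "default",
    "prompt": "List three prime numbers:",
    "max_tokens": 64,
    "min_tokens": 4,               # forwarded as min_new_tokens in this release
    "repetition_penalty": 1.1,     # newly forwarded
    "stop_token_ids": [128009],    # newly forwarded; placeholder token id
}
resp = requests.post("http://localhost:30000/v1/completions", json=payload)
print(resp.json()["choices"][0]["text"])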
@@ -380,7 +420,7 @@ def v1_generate_request(all_requests):
         else:
             prompt_kwargs = {"input_ids": prompt}
     else:
-        if isinstance(prompts[0], str)
+        if isinstance(prompts[0], str):
             prompt_kwargs = {"text": prompts}
         else:
             prompt_kwargs = {"input_ids": prompts}
@@ -463,14 +503,18 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
                 "index": 0,
                 "text": text,
                 "logprobs": logprobs,
-                "finish_reason":
+                "finish_reason": format_finish_reason(
+                    ret_item["meta_info"]["finish_reason"]
+                ),
             }
         else:
             choice_data = CompletionResponseChoice(
                 index=idx,
                 text=text,
                 logprobs=logprobs,
-                finish_reason=
+                finish_reason=format_finish_reason(
+                    ret_item["meta_info"]["finish_reason"]
+                ),
             )

         choices.append(choice_data)
@@ -500,7 +544,9 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
             responses.append(response)
         return responses
     else:
-        prompt_tokens = sum(
+        prompt_tokens = sum(
+            ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
+        )
         completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
         response = CompletionResponse(
             id=ret[0]["meta_info"]["id"],
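With n > 1 every prompt yields n entries in ret that repeat the same prompt_tokens, so stepping the sum by request.n counts each prompt once. A toy check with made-up numbers:

# Two prompts, n = 2 -> ret is ordered [p0c0, p0c1, p1c0, p1c1]
ret = [{"meta_info": {"prompt_tokens": 5}}] * 2 + [{"meta_info": {"prompt_tokens": 7}}] * 2
n = 2
prompt_tokens = sum(ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), n))
assert prompt_tokens == 12  # 5 + 7, not 5*2 + 7*2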
@@ -583,20 +629,34 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     index=0,
                     text=delta,
                     logprobs=logprobs,
-                    finish_reason=
+                    finish_reason=format_finish_reason(
+                        content["meta_info"]["finish_reason"]
+                    ),
                 )
                 chunk = CompletionStreamResponse(
                     id=content["meta_info"]["id"],
                     object="text_completion",
                     choices=[choice_data],
                     model=request.model,
-                    usage=UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=prompt_tokens + completion_tokens,
-                    ),
                 )
                 yield f"data: {chunk.model_dump_json()}\n\n"
+            if request.stream_options and request.stream_options.include_usage:
+                usage = UsageInfo(
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=prompt_tokens + completion_tokens,
+                )
+
+                final_usage_chunk = CompletionStreamResponse(
+                    id=str(uuid.uuid4().hex),
+                    choices=[],
+                    model=request.model,
+                    usage=usage,
+                )
+                final_usage_data = final_usage_chunk.model_dump_json(
+                    exclude_unset=True, exclude_none=True
+                )
+                yield f"data: {final_usage_data}\n\n"
             except ValueError as e:
                 error = create_streaming_error_response(str(e))
                 yield f"data: {error}\n\n"
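Per-chunk usage is dropped; a single usage-only chunk now arrives at the end of the stream when the client opts in via stream_options. A sketch with the OpenAI Python client pointed at a local server (base_url, api_key, and model are placeholders; requires a client version that supports stream_options):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
stream = client.completions.create(
    model="default",
    prompt="Say hi",
    max_tokens=16,
    stream=True,
    stream_options={"include_usage": True},  # opt in to the final usage-only chunk
)
for chunk in stream:
    if chunk.choices:
        print(chunk.choices[0].text, end="")
    elif chunk.usage is not None:
        print("\nusage:", chunk.usage)  # emitted once, after the last content chunk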
@@ -624,7 +684,6 @@ async def v1_completions(tokenizer_manager, raw_request: Request):


 def v1_chat_generate_request(all_requests, tokenizer_manager):
-
     input_ids = []
     sampling_params_list = []
     image_data_list = []
@@ -667,10 +726,13 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
             {
                 "temperature": request.temperature,
                 "max_new_tokens": request.max_tokens,
+                "min_new_tokens": request.min_tokens,
                 "stop": stop,
+                "stop_token_ids": request.stop_token_ids,
                 "top_p": request.top_p,
                 "presence_penalty": request.presence_penalty,
                 "frequency_penalty": request.frequency_penalty,
+                "repetition_penalty": request.repetition_penalty,
                 "regex": request.regex,
                 "n": request.n,
             }
@@ -707,8 +769,6 @@

 def v1_chat_generate_response(request, ret, to_file=False):
     choices = []
-    total_prompt_tokens = 0
-    total_completion_tokens = 0

     for idx, ret_item in enumerate(ret):
         logprobs = False
@@ -747,8 +807,6 @@ def v1_chat_generate_response(request, ret, to_file=False):
             choice_logprobs = ChoiceLogprobs(content=token_logprobs)
         else:
             choice_logprobs = None
-        prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
-        completion_tokens = ret_item["meta_info"]["completion_tokens"]

         if to_file:
             # to make the choice data json serializable
@@ -756,19 +814,22 @@ def v1_chat_generate_response(request, ret, to_file=False):
                 "index": 0,
                 "message": {"role": "assistant", "content": ret_item["text"]},
                 "logprobs": choice_logprobs,
-                "finish_reason":
+                "finish_reason": format_finish_reason(
+                    ret_item["meta_info"]["finish_reason"]
+                ),
             }
         else:
             choice_data = ChatCompletionResponseChoice(
                 index=idx,
                 message=ChatMessage(role="assistant", content=ret_item["text"]),
                 logprobs=choice_logprobs,
-                finish_reason=
+                finish_reason=format_finish_reason(
+                    ret_item["meta_info"]["finish_reason"]
+                ),
             )

         choices.append(choice_data)
-
-        total_completion_tokens += completion_tokens
+
     if to_file:
         responses = []

@@ -795,14 +856,18 @@ def v1_chat_generate_response(request, ret, to_file=False):
             responses.append(response)
         return responses
     else:
+        prompt_tokens = sum(
+            ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
+        )
+        completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
         response = ChatCompletionResponse(
             id=ret[0]["meta_info"]["id"],
             model=request.model,
             choices=choices,
             usage=UsageInfo(
-                prompt_tokens=
-                completion_tokens=
-                total_tokens=
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
             ),
         )
         return response
@@ -877,18 +942,15 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         choice_data = ChatCompletionResponseStreamChoice(
                             index=0,
                             delta=DeltaMessage(role="assistant"),
-                            finish_reason=
+                            finish_reason=format_finish_reason(
+                                content["meta_info"]["finish_reason"]
+                            ),
                             logprobs=choice_logprobs,
                         )
                         chunk = ChatCompletionStreamResponse(
                             id=content["meta_info"]["id"],
                             choices=[choice_data],
                             model=request.model,
-                            usage=UsageInfo(
-                                prompt_tokens=prompt_tokens,
-                                completion_tokens=completion_tokens,
-                                total_tokens=prompt_tokens + completion_tokens,
-                            ),
                         )
                         yield f"data: {chunk.model_dump_json()}\n\n"

@@ -898,20 +960,34 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=0,
                         delta=DeltaMessage(content=delta),
-                        finish_reason=
+                        finish_reason=format_finish_reason(
+                            content["meta_info"]["finish_reason"]
+                        ),
                         logprobs=choice_logprobs,
                     )
                     chunk = ChatCompletionStreamResponse(
                         id=content["meta_info"]["id"],
                         choices=[choice_data],
                         model=request.model,
-                        usage=UsageInfo(
-                            prompt_tokens=prompt_tokens,
-                            completion_tokens=completion_tokens,
-                            total_tokens=prompt_tokens + completion_tokens,
-                        ),
                     )
                     yield f"data: {chunk.model_dump_json()}\n\n"
+                if request.stream_options and request.stream_options.include_usage:
+                    usage = UsageInfo(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=prompt_tokens + completion_tokens,
+                    )
+
+                    final_usage_chunk = ChatCompletionStreamResponse(
+                        id=str(uuid.uuid4().hex),
+                        choices=[],
+                        model=request.model,
+                        usage=usage,
+                    )
+                    final_usage_data = final_usage_chunk.model_dump_json(
+                        exclude_unset=True, exclude_none=True
+                    )
+                    yield f"data: {final_usage_data}\n\n"
             except ValueError as e:
                 error = create_streaming_error_response(str(e))
                 yield f"data: {error}\n\n"
@@ -930,7 +1006,6 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
         ).__anext__()
     except ValueError as e:
         return create_error_response(str(e))
-
     if not isinstance(ret, list):
         ret = [ret]

@@ -939,6 +1014,81 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     return response


+def v1_embedding_request(all_requests, tokenizer_manager):
+    prompts = []
+    sampling_params_list = []
+    first_prompt_type = type(all_requests[0].input)
+
+    for request in all_requests:
+        prompt = request.input
+        assert (
+            type(prompt) == first_prompt_type
+        ), "All prompts must be of the same type in file input settings"
+        prompts.append(prompt)
+
+    if len(all_requests) == 1:
+        prompt = prompts[0]
+        if isinstance(prompt, str) or isinstance(prompt[0], str):
+            prompt_kwargs = {"text": prompt}
+        else:
+            prompt_kwargs = {"input_ids": prompt}
+    else:
+        if isinstance(prompts[0], str) or isinstance(propmt[0][0], str):
+            prompt_kwargs = {"text": prompts}
+        else:
+            prompt_kwargs = {"input_ids": prompts}
+
+    adapted_request = EmbeddingReqInput(
+        **prompt_kwargs,
+    )
+
+    if len(all_requests) == 1:
+        return adapted_request, all_requests[0]
+    return adapted_request, all_requests
+
+
+def v1_embedding_response(ret, model_path, to_file=False):
+    embedding_objects = []
+    prompt_tokens = 0
+    for idx, ret_item in enumerate(ret):
+        embedding_objects.append(
+            EmbeddingObject(
+                embedding=ret[idx]["embedding"],
+                index=idx,
+            )
+        )
+        prompt_tokens += ret[idx]["meta_info"]["prompt_tokens"]
+
+    return EmbeddingResponse(
+        data=embedding_objects,
+        model=model_path,
+        usage=UsageInfo(
+            prompt_tokens=prompt_tokens,
+            total_tokens=prompt_tokens,
+        ),
+    )
+
+
+async def v1_embeddings(tokenizer_manager, raw_request: Request):
+    request_json = await raw_request.json()
+    all_requests = [EmbeddingRequest(**request_json)]
+    adapted_request, request = v1_embedding_request(all_requests, tokenizer_manager)
+
+    try:
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request
+        ).__anext__()
+    except ValueError as e:
+        return create_error_response(str(e))
+
+    if not isinstance(ret, list):
+        ret = [ret]
+
+    response = v1_embedding_response(ret, tokenizer_manager.model_path)
+
+    return response
+
+
 def to_openai_style_logprobs(
     input_token_logprobs=None,
     output_token_logprobs=None,
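These handlers back an OpenAI-compatible /v1/embeddings endpoint for embedding models such as the LlamaEmbeddingModel added in this release (the route wiring is in server.py, which also changed). A hedged usage sketch with the OpenAI Python client, assuming the server was launched with an e5-mistral style embedding model (base_url, api_key, and model name are placeholders):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
resp = client.embeddings.create(
    model="intfloat/e5-mistral-7b-instruct",
    input="The capital of France is Paris.",
)
print(len(resp.data[0].embedding), resp.usage.prompt_tokens)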