sglang 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff shows the changes between two publicly available package versions as they were released to their public registry. It is provided for informational purposes only.
- sglang/__init__.py +33 -26
- sglang/api.py +9 -1
- sglang/bench_latency.py +2 -2
- sglang/bench_serving.py +10 -1
- sglang/check_env.py +1 -1
- sglang/lang/backend/litellm.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/interpreter.py +21 -5
- sglang/lang/ir.py +1 -2
- sglang/srt/constrained/__init__.py +15 -0
- sglang/srt/constrained/{base_cache.py → base_tool_cache.py} +17 -2
- sglang/srt/constrained/fsm_cache.py +17 -2
- sglang/srt/constrained/jump_forward.py +17 -2
- sglang/srt/conversation.py +26 -0
- sglang/srt/hf_transformers_utils.py +15 -0
- sglang/srt/layers/context_flashattention_nopad.py +15 -0
- sglang/srt/layers/extend_attention.py +15 -0
- sglang/srt/layers/fused_moe.py +15 -0
- sglang/srt/layers/linear.py +15 -0
- sglang/srt/layers/logits_processor.py +41 -13
- sglang/srt/layers/quantization/__init__.py +15 -0
- sglang/srt/layers/quantization/fp8.py +15 -0
- sglang/srt/layers/radix_attention.py +17 -2
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/{controller/manager_multi.py → controller_multi.py} +17 -2
- sglang/srt/managers/{controller/manager_single.py → controller_single.py} +17 -2
- sglang/srt/managers/detokenizer_manager.py +16 -1
- sglang/srt/managers/io_struct.py +36 -3
- sglang/srt/managers/{controller/schedule_heuristic.py → policy_scheduler.py} +37 -22
- sglang/srt/managers/{controller/infer_batch.py → schedule_batch.py} +60 -21
- sglang/srt/managers/tokenizer_manager.py +39 -16
- sglang/srt/managers/{controller/tp_worker.py → tp_worker.py} +159 -46
- sglang/srt/mem_cache/base_cache.py +43 -0
- sglang/srt/mem_cache/chunk_cache.py +60 -0
- sglang/srt/mem_cache/flush_cache.py +33 -0
- sglang/srt/{memory_pool.py → mem_cache/memory_pool.py} +16 -1
- sglang/srt/{managers/controller → mem_cache}/radix_cache.py +20 -2
- sglang/srt/mm_utils.py +15 -0
- sglang/srt/model_config.py +15 -0
- sglang/srt/{managers/controller → model_executor}/cuda_graph_runner.py +16 -1
- sglang/srt/{managers/controller → model_executor}/model_runner.py +49 -14
- sglang/srt/model_loader/model_loader.py +15 -0
- sglang/srt/model_loader/utils.py +16 -1
- sglang/srt/models/chatglm.py +16 -1
- sglang/srt/models/commandr.py +16 -1
- sglang/srt/models/dbrx.py +16 -1
- sglang/srt/models/deepseek.py +16 -1
- sglang/srt/models/deepseek_v2.py +16 -1
- sglang/srt/models/gemma.py +16 -1
- sglang/srt/models/gemma2.py +16 -1
- sglang/srt/models/gpt_bigcode.py +16 -1
- sglang/srt/models/grok.py +16 -1
- sglang/srt/models/internlm2.py +16 -1
- sglang/srt/models/llama2.py +21 -22
- sglang/srt/models/llama_classification.py +16 -1
- sglang/srt/models/llava.py +17 -2
- sglang/srt/models/llavavid.py +17 -2
- sglang/srt/models/minicpm.py +16 -1
- sglang/srt/models/mistral.py +15 -0
- sglang/srt/models/mixtral.py +16 -1
- sglang/srt/models/mixtral_quant.py +16 -1
- sglang/srt/models/qwen.py +16 -1
- sglang/srt/models/qwen2.py +16 -1
- sglang/srt/models/qwen2_moe.py +16 -1
- sglang/srt/models/stablelm.py +16 -1
- sglang/srt/models/yivl.py +15 -0
- sglang/srt/openai_api/adapter.py +569 -131
- sglang/srt/openai_api/protocol.py +84 -2
- sglang/srt/sampling_params.py +15 -0
- sglang/srt/server.py +92 -23
- sglang/srt/server_args.py +52 -11
- sglang/srt/utils.py +15 -0
- sglang/test/test_programs.py +9 -6
- sglang/utils.py +22 -0
- sglang/version.py +1 -1
- {sglang-0.2.6.dist-info → sglang-0.2.8.dist-info}/METADATA +33 -7
- sglang-0.2.8.dist-info/RECORD +95 -0
- {sglang-0.2.6.dist-info → sglang-0.2.8.dist-info}/WHEEL +1 -1
- sglang/srt/flush_cache.py +0 -18
- sglang-0.2.6.dist-info/RECORD +0 -93
- {sglang-0.2.6.dist-info → sglang-0.2.8.dist-info}/LICENSE +0 -0
- {sglang-0.2.6.dist-info → sglang-0.2.8.dist-info}/top_level.txt +0 -0
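The largest change is sglang/srt/openai_api/adapter.py, shown below: it adds OpenAI-compatible file upload and batch processing (v1_files_create, v1_batches, process_batch, v1_retrieve_batch, v1_retrieve_file, v1_retrieve_file_content) and refactors the completion and chat-completion paths into shared request/response builders. A minimal usage sketch with the OpenAI Python client, assuming the 0.2.8 server wires these handlers to the usual OpenAI-style routes (/v1/files, /v1/batches; the routing lives in sglang/srt/server.py, which is not shown here); the host, port, model, and file names are placeholders:

# Hypothetical client-side sketch against a locally running sglang 0.2.8 server.
# Assumes the new adapter handlers are exposed at the OpenAI-style routes
# (/v1/files, /v1/batches, ...); base_url, port, and file names are placeholders,
# not values taken from this diff.
import time

from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

# Upload a JSONL batch input file (one request per line; see the input-format
# sketch further below).
batch_file = client.files.create(file=open("batch_input.jsonl", "rb"), purpose="batch")

# Create the batch job; the server runs process_batch asynchronously.
batch = client.batches.create(
    input_file_id=batch_file.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
)

# Poll until the server marks the batch "completed" or "failed".
while batch.status not in ("completed", "failed"):
    time.sleep(1)
    batch = client.batches.retrieve(batch.id)

# Download the per-request results that process_batch wrote to the output file.
if batch.status == "completed":
    print(client.files.content(batch.output_file_id).content.decode("utf-8"))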
sglang/srt/openai_api/adapter.py
CHANGED
@@ -1,12 +1,31 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """Conversion between OpenAI APIs and native SRT APIs"""
 
 import asyncio
 import json
 import os
+import time
+import uuid
 from http import HTTPStatus
+from typing import Dict, List, Optional
 
-from fastapi import Request
+from fastapi import HTTPException, Request, UploadFile
 from fastapi.responses import JSONResponse, StreamingResponse
+from pydantic import ValidationError
 
 from sglang.srt.conversation import (
     Conversation,
@@ -17,12 +36,16 @@ from sglang.srt.conversation import (
 )
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.openai_api.protocol import (
+    BatchRequest,
+    BatchResponse,
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseChoice,
     ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse,
+    ChatCompletionTokenLogprob,
     ChatMessage,
+    ChoiceLogprobs,
     CompletionRequest,
     CompletionResponse,
     CompletionResponseChoice,
@@ -30,13 +53,34 @@ from sglang.srt.openai_api.protocol import (
     CompletionStreamResponse,
     DeltaMessage,
     ErrorResponse,
+    FileRequest,
+    FileResponse,
     LogProbs,
+    TopLogprob,
     UsageInfo,
 )
 
 chat_template_name = None
 
 
+class FileMetadata:
+    def __init__(self, filename: str, purpose: str):
+        self.filename = filename
+        self.purpose = purpose
+
+
+# In-memory storage for batch jobs and files
+batch_storage: Dict[str, BatchResponse] = {}
+file_id_request: Dict[str, FileMetadata] = {}
+file_id_response: Dict[str, FileResponse] = {}
+# map file id to file path in SGlang backend
+file_id_storage: Dict[str, str] = {}
+
+
+# backend storage directory
+storage_dir = None
+
+
 def create_error_response(
     message: str,
     err_type: str = "BadRequestError",
@@ -91,33 +135,368 @@ def load_chat_template_for_openai_api(chat_template_arg):
         chat_template_name = chat_template_arg
 
 
-async def
-
-
-
-
-
-
-
+async def v1_files_create(file: UploadFile, purpose: str, file_storage_pth: str = None):
+    try:
+        global storage_dir
+        if file_storage_pth:
+            storage_dir = file_storage_pth
+        # Read the file content
+        file_content = await file.read()
+
+        # Create an instance of RequestBody
+        request_body = FileRequest(file=file_content, purpose=purpose)
+
+        # Save the file to the sglang_oai_storage directory
+        os.makedirs(storage_dir, exist_ok=True)
+        file_id = f"backend_input_file-{uuid.uuid4()}"
+        filename = f"{file_id}.jsonl"
+        file_path = os.path.join(storage_dir, filename)
+
+        with open(file_path, "wb") as f:
+            f.write(request_body.file)
+
+        # add info to global file map
+        file_id_request[file_id] = FileMetadata(filename=file.filename, purpose=purpose)
+        file_id_storage[file_id] = file_path
+
+        # Return the response in the required format
+        response = FileResponse(
+            id=file_id,
+            bytes=len(request_body.file),
+            created_at=int(time.time()),
+            filename=file.filename,
+            purpose=request_body.purpose,
+        )
+        file_id_response[file_id] = response
+
+        return response
+    except ValidationError as e:
+        return {"error": "Invalid input", "details": e.errors()}
+
+
+async def v1_batches(tokenizer_manager, raw_request: Request):
+    try:
+        body = await raw_request.json()
 
+        batch_request = BatchRequest(**body)
+
+        batch_id = f"batch_{uuid.uuid4()}"
+
+        # Create an instance of BatchResponse
+        batch_response = BatchResponse(
+            id=batch_id,
+            endpoint=batch_request.endpoint,
+            input_file_id=batch_request.input_file_id,
+            completion_window=batch_request.completion_window,
+            created_at=int(time.time()),
+            metadata=batch_request.metadata,
+        )
+
+        batch_storage[batch_id] = batch_response
+
+        # Start processing the batch asynchronously
+        asyncio.create_task(process_batch(tokenizer_manager, batch_id, batch_request))
+
+        # Return the initial batch_response
+        return batch_response
+
+    except ValidationError as e:
+        return {"error": "Invalid input", "details": e.errors()}
+    except Exception as e:
+        return {"error": str(e)}
+
+
+async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest):
+    try:
+        # Update the batch status to "in_progress"
+        batch_storage[batch_id].status = "in_progress"
+        batch_storage[batch_id].in_progress_at = int(time.time())
+
+        # Retrieve the input file content
+        input_file_request = file_id_request.get(batch_request.input_file_id)
+        if not input_file_request:
+            raise ValueError("Input file not found")
+
+        # Parse the JSONL file and process each request
+        input_file_path = file_id_storage.get(batch_request.input_file_id)
+        with open(input_file_path, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+
+        total_requests = len(lines)
+        completed_requests = 0
+        failed_requests = 0
+
+        all_ret = []
+        end_point = batch_storage[batch_id].endpoint
+        file_request_list = []
+        all_requests = []
+        for line in lines:
+            request_data = json.loads(line)
+            file_request_list.append(request_data)
+            body = request_data["body"]
+            if end_point == "/v1/chat/completions":
+                all_requests.append(ChatCompletionRequest(**body))
+            elif end_point == "/v1/completions":
+                all_requests.append(CompletionRequest(**body))
+        if end_point == "/v1/chat/completions":
+            adapted_request, request = v1_chat_generate_request(
+                all_requests, tokenizer_manager
+            )
+        elif end_point == "/v1/completions":
+            adapted_request, request = v1_generate_request(all_requests)
+        try:
+            ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
+            if not isinstance(ret, list):
+                ret = [ret]
+            if end_point == "/v1/chat/completions":
+                responses = v1_chat_generate_response(request, ret, to_file=True)
+            else:
+                responses = v1_generate_response(request, ret, to_file=True)
+
+        except Exception as e:
+            error_json = {
+                "id": f"batch_req_{uuid.uuid4()}",
+                "custom_id": request_data.get("custom_id"),
+                "response": None,
+                "error": {"message": str(e)},
+            }
+            all_ret.append(error_json)
+            failed_requests += len(file_request_list)
+
+        for idx, response in enumerate(responses):
+            # the batch_req here can be changed to be named within a batch granularity
+            response_json = {
+                "id": f"batch_req_{uuid.uuid4()}",
+                "custom_id": file_request_list[idx].get("custom_id"),
+                "response": response,
+                "error": None,
+            }
+            all_ret.append(response_json)
+            completed_requests += 1
+        # Write results to a new file
+        output_file_id = f"backend_result_file-{uuid.uuid4()}"
+        global storage_dir
+        output_file_path = os.path.join(storage_dir, f"{output_file_id}.jsonl")
+        with open(output_file_path, "w", encoding="utf-8") as f:
+            for ret in all_ret:
+                f.write(json.dumps(ret) + "\n")
+
+        # Update batch response with output file information
+        retrieve_batch = batch_storage[batch_id]
+        retrieve_batch.output_file_id = output_file_id
+        file_id_storage[output_file_id] = output_file_path
+        # Update batch status to "completed"
+        retrieve_batch.status = "completed"
+        retrieve_batch.completed_at = int(time.time())
+        retrieve_batch.request_counts = {
+            "total": total_requests,
+            "completed": completed_requests,
+            "failed": failed_requests,
+        }
+
+    except Exception as e:
+        print("error in SGlang:", e)
+        # Update batch status to "failed"
+        retrieve_batch = batch_storage[batch_id]
+        retrieve_batch.status = "failed"
+        retrieve_batch.failed_at = int(time.time())
+        retrieve_batch.errors = {"message": str(e)}
+
+
+async def v1_retrieve_batch(batch_id: str):
+    # Retrieve the batch job from the in-memory storage
+    batch_response = batch_storage.get(batch_id)
+    if batch_response is None:
+        raise HTTPException(status_code=404, detail="Batch not found")
+
+    return batch_response
+
+
+async def v1_retrieve_file(file_id: str):
+    # Retrieve the batch job from the in-memory storage
+    file_response = file_id_response.get(file_id)
+    if file_response is None:
+        raise HTTPException(status_code=404, detail="File not found")
+    return file_response
+
+
+async def v1_retrieve_file_content(file_id: str):
+    file_pth = file_id_storage.get(file_id)
+    if not file_pth or not os.path.exists(file_pth):
+        raise HTTPException(status_code=404, detail="File not found")
+
+    def iter_file():
+        with open(file_pth, mode="rb") as file_like:
+            yield from file_like
+
+    return StreamingResponse(iter_file(), media_type="application/octet-stream")
+
+
+def v1_generate_request(all_requests):
+
+    prompts = []
+    sampling_params_list = []
+    return_logprobs = []
+    top_logprobs_nums = []
+    first_prompt_type = type(all_requests[0].prompt)
+    for request in all_requests:
+        prompt = request.prompt
+        assert (
+            type(prompt) == first_prompt_type
+        ), "All prompts must be of the same type in file input settings"
+        prompts.append(prompt)
+        return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
+        top_logprobs_nums.append(
+            request.logprobs if request.logprobs is not None else 0
+        )
+        sampling_params_list.append(
+            {
+                "temperature": request.temperature,
+                "max_new_tokens": request.max_tokens,
+                "stop": request.stop,
+                "top_p": request.top_p,
+                "presence_penalty": request.presence_penalty,
+                "frequency_penalty": request.frequency_penalty,
+                "regex": request.regex,
+                "n": request.n,
+                "ignore_eos": request.ignore_eos,
+            }
+        )
+        if len(all_requests) > 1 and request.n > 1:
+            raise ValueError(
+                "Batch operation is not supported for completions from files"
+            )
+
+    if len(all_requests) == 1:
+        prompt = prompts[0]
+        sampling_params_list = sampling_params_list[0]
+        return_logprobs = return_logprobs[0]
+        top_logprobs_nums = top_logprobs_nums[0]
+        if isinstance(prompt, str) or isinstance(prompt[0], str):
+            prompt_kwargs = {"text": prompt}
+        else:
+            prompt_kwargs = {"input_ids": prompt}
+    else:
+        if isinstance(prompts[0], str):
+            prompt_kwargs = {"text": prompts}
+        else:
+            prompt_kwargs = {"input_ids": prompts}
     adapted_request = GenerateReqInput(
         **prompt_kwargs,
-        sampling_params=
-
-
-            "stop": request.stop,
-            "top_p": request.top_p,
-            "presence_penalty": request.presence_penalty,
-            "frequency_penalty": request.frequency_penalty,
-            "regex": request.regex,
-            "n": request.n,
-            "ignore_eos": request.ignore_eos,
-        },
-        return_logprob=request.logprobs is not None and request.logprobs > 0,
-        top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
+        sampling_params=sampling_params_list,
+        return_logprob=return_logprobs,
+        top_logprobs_num=top_logprobs_nums,
         return_text_in_logprobs=True,
-        stream=
+        stream=all_requests[0].stream,
     )
+    if len(all_requests) == 1:
+        return adapted_request, all_requests[0]
+    return adapted_request, all_requests
+
+
+def v1_generate_response(request, ret, to_file=False):
+    choices = []
+    echo = False
+
+    if (not isinstance(request, List)) and request.echo:
+        # TODO: handle the case propmt is token ids
+        if isinstance(request.prompt, list):
+            prompts = request.prompt
+        else:
+            prompts = [request.prompt]
+        echo = True
+
+    for idx, ret_item in enumerate(ret):
+        text = ret_item["text"]
+        if isinstance(request, List) and request[idx].echo:
+            echo = True
+            text = request[idx].prompt + text
+        if (not isinstance(request, List)) and echo:
+            text = prompts[idx] + text
+
+        logprobs = False
+        if isinstance(request, List) and request[idx].logprobs:
+            logprobs = True
+        elif (not isinstance(request, List)) and request.logprobs:
+            logprobs = True
+        if logprobs:
+            if echo:
+                input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
+                input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
+            else:
+                input_token_logprobs = None
+                input_top_logprobs = None
+
+            logprobs = to_openai_style_logprobs(
+                input_token_logprobs=input_token_logprobs,
+                input_top_logprobs=input_top_logprobs,
+                output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
+                output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
+            )
+        else:
+            logprobs = None
+
+        if to_file:
+            # to make the choise data json serializable
+            choice_data = {
+                "index": 0,
+                "text": text,
+                "logprobs": logprobs,
+                "finish_reason": ret_item["meta_info"]["finish_reason"],
+            }
+        else:
+            choice_data = CompletionResponseChoice(
+                index=idx,
+                text=text,
+                logprobs=logprobs,
+                finish_reason=ret_item["meta_info"]["finish_reason"],
+            )
+
+        choices.append(choice_data)
+
+    if to_file:
+        responses = []
+        for i, choice in enumerate(choices):
+            response = {
+                "status_code": 200,
+                "request_id": ret[i]["meta_info"]["id"],
+                "body": {
+                    # remain the same but if needed we can change that
+                    "id": ret[i]["meta_info"]["id"],
+                    "object": "text_completion",
+                    "created": int(time.time()),
+                    "model": request[i].model,
+                    "choices": choice,
+                    "usage": {
+                        "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"],
+                        "completion_tokens": ret[i]["meta_info"]["completion_tokens"],
+                        "total_tokens": ret[i]["meta_info"]["prompt_tokens"]
+                        + ret[i]["meta_info"]["completion_tokens"],
+                    },
+                    "system_fingerprint": None,
+                },
+            }
+            responses.append(response)
+        return responses
+    else:
+        completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
+        response = CompletionResponse(
+            id=ret[0]["meta_info"]["id"],
+            model=request.model,
+            choices=choices,
+            usage=UsageInfo(
+                prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
+                completion_tokens=completion_tokens,
+                total_tokens=ret[0]["meta_info"]["prompt_tokens"] + completion_tokens,
+            ),
+        )
+        return response
+
+
+async def v1_completions(tokenizer_manager, raw_request: Request):
+    request_json = await raw_request.json()
+    all_requests = [CompletionRequest(**request_json)]
+    adapted_request, request = v1_generate_request(all_requests)
 
     if adapted_request.stream:
 
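process_batch above parses the uploaded input file one JSON object per line: each line's body is validated as a ChatCompletionRequest or CompletionRequest depending on the batch endpoint, and its custom_id is echoed back next to the corresponding result. A minimal sketch of producing such an input file for the "/v1/chat/completions" endpoint (only the field names the code reads are assumed; the model name and prompts are placeholders):

# Hypothetical generator for the batch input file read by process_batch above.
import json

rows = [
    {
        "custom_id": f"request-{i}",
        "body": {
            "model": "placeholder-model",
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 64,
        },
    }
    for i, prompt in enumerate(["Hello!", "What is SGLang?"])
]

with open("batch_input.jsonl", "w", encoding="utf-8") as f:
    for row in rows:
        # One JSON object per line; "body" is parsed as a ChatCompletionRequest
        # when the batch endpoint is "/v1/chat/completions".
        f.write(json.dumps(row) + "\n")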
@@ -208,105 +587,190 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
 
     if not isinstance(ret, list):
         ret = [ret]
-    choices = []
 
-
-
+    response = v1_generate_response(request, ret)
+    return response
 
-        if request.echo:
-            text = request.prompt + text
 
-
-
-
-
+def v1_chat_generate_request(all_requests, tokenizer_manager):
+
+    texts = []
+    sampling_params_list = []
+    image_data_list = []
+    return_logprobs = []
+    top_logprobs_nums = []
+    for request in all_requests:
+        # Prep the data needed for the underlying GenerateReqInput:
+        # - prompt: The full prompt string.
+        # - stop: Custom stop tokens.
+        # - image_data: None or a list of image strings (URLs or base64 strings).
+        #   None skips any image processing in GenerateReqInput.
+        if not isinstance(request.messages, str):
+            # Apply chat template and its stop strings.
+            if chat_template_name is None:
+                prompt = tokenizer_manager.tokenizer.apply_chat_template(
+                    request.messages, tokenize=False, add_generation_prompt=True
+                )
+                stop = request.stop
+                image_data = None
             else:
-
-
+                conv = generate_chat_conv(request, chat_template_name)
+                prompt = conv.get_prompt()
+                image_data = conv.image_data
+                stop = conv.stop_str or []
+                if request.stop:
+                    if isinstance(request.stop, str):
+                        stop.append(request.stop)
+                    else:
+                        stop.extend(request.stop)
+        else:
+            # Use the raw prompt and stop strings if the messages is already a string.
+            prompt = request.messages
+            stop = request.stop
+            image_data = None
+        texts.append(prompt)
+        return_logprobs.append(request.logprobs)
+        top_logprobs_nums.append(request.top_logprobs)
+        sampling_params_list.append(
+            {
+                "temperature": request.temperature,
+                "max_new_tokens": request.max_tokens,
+                "stop": stop,
+                "top_p": request.top_p,
+                "presence_penalty": request.presence_penalty,
+                "frequency_penalty": request.frequency_penalty,
+                "regex": request.regex,
+                "n": request.n,
+            }
+        )
+        image_data_list.append(image_data)
+    if len(all_requests) == 1:
+        texts = texts[0]
+        sampling_params_list = sampling_params_list[0]
+        image_data = image_data_list[0]
+        return_logprobs = return_logprobs[0]
+        top_logprobs_nums = top_logprobs_nums[0]
+    adapted_request = GenerateReqInput(
+        text=texts,
+        image_data=image_data,
+        sampling_params=sampling_params_list,
+        return_logprob=return_logprobs,
+        top_logprobs_num=top_logprobs_nums,
+        stream=all_requests[0].stream,
+        return_text_in_logprobs=True,
+    )
+    if len(all_requests) == 1:
+        return adapted_request, all_requests[0]
+    return adapted_request, all_requests
+
+
+def v1_chat_generate_response(request, ret, to_file=False):
+    choices = []
+    total_prompt_tokens = 0
+    total_completion_tokens = 0
 
+    for idx, ret_item in enumerate(ret):
+        logprobs = False
+        if isinstance(request, List) and request[idx].logprobs:
+            logprobs = True
+        elif (not isinstance(request, List)) and request.logprobs:
+            logprobs = True
+        if logprobs:
             logprobs = to_openai_style_logprobs(
-                input_token_logprobs=input_token_logprobs,
-                input_top_logprobs=input_top_logprobs,
                 output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
                 output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
             )
+            token_logprobs = []
+            for token, logprob in zip(logprobs.tokens, logprobs.token_logprobs):
+                token_bytes = list(token.encode("utf-8"))
+                top_logprobs = []
+                if logprobs.top_logprobs:
+                    for top_token, top_logprob in logprobs.top_logprobs[0].items():
+                        top_token_bytes = list(top_token.encode("utf-8"))
+                        top_logprobs.append(
+                            TopLogprob(
+                                token=top_token,
+                                bytes=top_token_bytes,
+                                logprob=top_logprob,
+                            )
+                        )
+                token_logprobs.append(
+                    ChatCompletionTokenLogprob(
+                        token=token,
+                        bytes=token_bytes,
+                        logprob=logprob,
+                        top_logprobs=top_logprobs,
+                    )
+                )
+
+            choice_logprobs = ChoiceLogprobs(content=token_logprobs)
         else:
-
+            choice_logprobs = None
+        prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
+        completion_tokens = ret_item["meta_info"]["completion_tokens"]
 
-
-
-
-
-
-
+        if to_file:
+            # to make the choice data json serializable
+            choice_data = {
+                "index": 0,
+                "message": {"role": "assistant", "content": ret_item["text"]},
+                "logprobs": choice_logprobs,
+                "finish_reason": ret_item["meta_info"]["finish_reason"],
+            }
+        else:
+            choice_data = ChatCompletionResponseChoice(
+                index=idx,
+                message=ChatMessage(role="assistant", content=ret_item["text"]),
+                logprobs=choice_logprobs,
+                finish_reason=ret_item["meta_info"]["finish_reason"],
+            )
 
         choices.append(choice_data)
-
-
-
-
-
-
-
-
-
+        total_prompt_tokens += prompt_tokens
+        total_completion_tokens += completion_tokens
+    if to_file:
+        responses = []
+
+        for i, choice in enumerate(choices):
+            response = {
+                "status_code": 200,
+                "request_id": ret[i]["meta_info"]["id"],
+                "body": {
+                    # remain the same but if needed we can change that
+                    "id": ret[i]["meta_info"]["id"],
+                    "object": "chat.completion",
+                    "created": int(time.time()),
+                    "model": request[i].model,
+                    "choices": choice,
+                    "usage": {
+                        "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"],
+                        "completion_tokens": ret[i]["meta_info"]["completion_tokens"],
+                        "total_tokens": ret[i]["meta_info"]["prompt_tokens"]
+                        + ret[i]["meta_info"]["completion_tokens"],
+                    },
+                    "system_fingerprint": None,
+                },
+            }
+            responses.append(response)
+        return responses
+    else:
+        response = ChatCompletionResponse(
+            id=ret[0]["meta_info"]["id"],
+            model=request.model,
+            choices=choices,
+            usage=UsageInfo(
+                prompt_tokens=total_prompt_tokens,
+                completion_tokens=total_completion_tokens,
+                total_tokens=total_prompt_tokens + total_completion_tokens,
             ),
-
-
-        ),
-    )
-
-    return response
+        )
+        return response
 
 
 async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
-
-
-    # Prep the data needed for the underlying GenerateReqInput:
-    # - prompt: The full prompt string.
-    # - stop: Custom stop tokens.
-    # - image_data: None or a list of image strings (URLs or base64 strings).
-    #   None skips any image processing in GenerateReqInput.
-    if not isinstance(request.messages, str):
-        # Apply chat template and its stop strings.
-        if chat_template_name is None:
-            prompt = tokenizer_manager.tokenizer.apply_chat_template(
-                request.messages, tokenize=False, add_generation_prompt=True
-            )
-            stop = request.stop
-            image_data = None
-        else:
-            conv = generate_chat_conv(request, chat_template_name)
-            prompt = conv.get_prompt()
-            image_data = conv.image_data
-            stop = conv.stop_str or []
-            if request.stop:
-                if isinstance(request.stop, str):
-                    stop.append(request.stop)
-                else:
-                    stop.extend(request.stop)
-    else:
-        # Use the raw prompt and stop strings if the messages is already a string.
-        prompt = request.messages
-        stop = request.stop
-        image_data = None
-
-    adapted_request = GenerateReqInput(
-        text=prompt,
-        image_data=image_data,
-        sampling_params={
-            "temperature": request.temperature,
-            "max_new_tokens": request.max_tokens,
-            "stop": stop,
-            "top_p": request.top_p,
-            "presence_penalty": request.presence_penalty,
-            "frequency_penalty": request.frequency_penalty,
-            "regex": request.regex,
-            "n": request.n,
-        },
-        stream=request.stream,
-    )
+    all_requests = [ChatCompletionRequest(**request_json)]
+    adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)
 
     if adapted_request.stream:
 
@@ -368,34 +832,8 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
 
     if not isinstance(ret, list):
         ret = [ret]
-    choices = []
-    total_prompt_tokens = 0
-    total_completion_tokens = 0
 
-
-        prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
-        completion_tokens = ret_item["meta_info"]["completion_tokens"]
-
-        choice_data = ChatCompletionResponseChoice(
-            index=idx,
-            message=ChatMessage(role="assistant", content=ret_item["text"]),
-            finish_reason=ret_item["meta_info"]["finish_reason"],
-        )
-
-        choices.append(choice_data)
-        total_prompt_tokens = prompt_tokens
-        total_completion_tokens += completion_tokens
-
-    response = ChatCompletionResponse(
-        id=ret[0]["meta_info"]["id"],
-        model=request.model,
-        choices=choices,
-        usage=UsageInfo(
-            prompt_tokens=total_prompt_tokens,
-            completion_tokens=total_completion_tokens,
-            total_tokens=total_prompt_tokens + total_completion_tokens,
-        ),
-    )
+    response = v1_chat_generate_response(request, ret)
 
     return response
 
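With to_file=True, v1_generate_response and v1_chat_generate_response return plain dictionaries instead of pydantic models, and process_batch writes one result record per input line to the output JSONL, either with "response" set and "error": None on success, or "response": None and an "error" object on failure. A sketch of reading those records back for a "/v1/chat/completions" batch (only the field names written by the code above are assumed; the file name is a placeholder):

# Hypothetical reader for a downloaded batch output file produced by process_batch.
import json

with open("batch_output.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)
        if record["error"] is not None:
            print(record["custom_id"], "failed:", record["error"]["message"])
            continue
        body = record["response"]["body"]
        # In the to_file branch, "choices" holds a single choice dict, not a list.
        print(record["custom_id"], "->", body["choices"]["message"]["content"])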