sglang 0.2.6__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +33 -26
- sglang/api.py +9 -1
- sglang/bench_latency.py +2 -2
- sglang/bench_serving.py +10 -1
- sglang/check_env.py +1 -1
- sglang/lang/backend/litellm.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/interpreter.py +20 -5
- sglang/lang/ir.py +1 -1
- sglang/srt/constrained/__init__.py +15 -0
- sglang/srt/constrained/base_cache.py +15 -0
- sglang/srt/constrained/fsm_cache.py +15 -0
- sglang/srt/constrained/jump_forward.py +15 -0
- sglang/srt/conversation.py +26 -0
- sglang/srt/hf_transformers_utils.py +15 -0
- sglang/srt/layers/context_flashattention_nopad.py +15 -0
- sglang/srt/layers/extend_attention.py +15 -0
- sglang/srt/layers/fused_moe.py +15 -0
- sglang/srt/layers/linear.py +15 -0
- sglang/srt/layers/logits_processor.py +41 -13
- sglang/srt/layers/quantization/__init__.py +15 -0
- sglang/srt/layers/quantization/fp8.py +15 -0
- sglang/srt/layers/radix_attention.py +17 -2
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/{controller/manager_multi.py → controller_multi.py} +17 -2
- sglang/srt/managers/{controller/manager_single.py → controller_single.py} +17 -2
- sglang/srt/managers/detokenizer_manager.py +16 -1
- sglang/srt/managers/io_struct.py +36 -3
- sglang/srt/managers/{controller/schedule_heuristic.py → policy_scheduler.py} +37 -22
- sglang/srt/managers/{controller/infer_batch.py → schedule_batch.py} +31 -12
- sglang/srt/managers/tokenizer_manager.py +39 -16
- sglang/srt/managers/{controller/tp_worker.py → tp_worker.py} +130 -40
- sglang/srt/mem_cache/flush_cache.py +33 -0
- sglang/srt/{memory_pool.py → mem_cache/memory_pool.py} +16 -1
- sglang/srt/{managers/controller → mem_cache}/radix_cache.py +15 -0
- sglang/srt/mm_utils.py +15 -0
- sglang/srt/model_config.py +15 -0
- sglang/srt/{managers/controller → model_executor}/cuda_graph_runner.py +16 -1
- sglang/srt/{managers/controller → model_executor}/model_runner.py +32 -12
- sglang/srt/model_loader/model_loader.py +15 -0
- sglang/srt/model_loader/utils.py +16 -1
- sglang/srt/models/chatglm.py +16 -1
- sglang/srt/models/commandr.py +16 -1
- sglang/srt/models/dbrx.py +16 -1
- sglang/srt/models/deepseek.py +16 -1
- sglang/srt/models/deepseek_v2.py +16 -1
- sglang/srt/models/gemma.py +16 -1
- sglang/srt/models/gemma2.py +16 -1
- sglang/srt/models/gpt_bigcode.py +16 -1
- sglang/srt/models/grok.py +16 -1
- sglang/srt/models/internlm2.py +16 -1
- sglang/srt/models/llama2.py +16 -1
- sglang/srt/models/llama_classification.py +16 -1
- sglang/srt/models/llava.py +17 -2
- sglang/srt/models/llavavid.py +17 -2
- sglang/srt/models/minicpm.py +16 -1
- sglang/srt/models/mistral.py +15 -0
- sglang/srt/models/mixtral.py +16 -1
- sglang/srt/models/mixtral_quant.py +16 -1
- sglang/srt/models/qwen.py +16 -1
- sglang/srt/models/qwen2.py +16 -1
- sglang/srt/models/qwen2_moe.py +16 -1
- sglang/srt/models/stablelm.py +16 -1
- sglang/srt/models/yivl.py +15 -0
- sglang/srt/openai_api/adapter.py +520 -135
- sglang/srt/openai_api/protocol.py +64 -0
- sglang/srt/sampling_params.py +15 -0
- sglang/srt/server.py +89 -23
- sglang/srt/server_args.py +49 -11
- sglang/srt/utils.py +15 -0
- sglang/utils.py +22 -0
- sglang/version.py +1 -1
- {sglang-0.2.6.dist-info → sglang-0.2.7.dist-info}/METADATA +32 -6
- sglang-0.2.7.dist-info/RECORD +93 -0
- {sglang-0.2.6.dist-info → sglang-0.2.7.dist-info}/WHEEL +1 -1
- sglang/srt/flush_cache.py +0 -18
- sglang-0.2.6.dist-info/RECORD +0 -93
- {sglang-0.2.6.dist-info → sglang-0.2.7.dist-info}/LICENSE +0 -0
- {sglang-0.2.6.dist-info → sglang-0.2.7.dist-info}/top_level.txt +0 -0
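The largest change in this release is sglang/srt/openai_api/adapter.py (+520 -135), shown in detail below: it adds OpenAI-compatible Files and Batch handlers (v1_files_create, v1_batches, process_batch, v1_retrieve_batch, v1_retrieve_file, v1_retrieve_file_content) next to the existing completions and chat-completions endpoints. A minimal client sketch of the new workflow follows; it assumes the server (wired up in server.py, which is not expanded here) registers these handlers at the standard OpenAI paths on a local instance at port 30000 and accepts the upload as a multipart form with a "purpose" field, as the OpenAI Files API does. None of those routing details are visible in this diff.

import json
import time

import requests

base_url = "http://localhost:30000"  # assumed local sglang server address

# 1) Upload a JSONL batch-input file.
with open("batch_input.jsonl", "rb") as f:
    uploaded = requests.post(
        f"{base_url}/v1/files",
        files={"file": f},
        data={"purpose": "batch"},
    ).json()

# 2) Create a batch job that points at the uploaded file.
batch = requests.post(
    f"{base_url}/v1/batches",
    json={
        "input_file_id": uploaded["id"],
        "endpoint": "/v1/chat/completions",
        "completion_window": "24h",
    },
).json()

# 3) Poll until process_batch marks the job completed or failed,
#    then fetch the result file written by the backend.
while batch.get("status") not in ("completed", "failed"):
    time.sleep(1)
    batch = requests.get(f"{base_url}/v1/batches/{batch['id']}").json()

if batch.get("status") == "completed":
    content = requests.get(
        f"{base_url}/v1/files/{batch['output_file_id']}/content"
    ).content
    for line in content.splitlines():
        record = json.loads(line)
        print(record["custom_id"], record["response"]["body"]["choices"])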
sglang/srt/openai_api/adapter.py
CHANGED
@@ -1,12 +1,31 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """Conversion between OpenAI APIs and native SRT APIs"""
 
 import asyncio
 import json
 import os
+import time
+import uuid
 from http import HTTPStatus
+from typing import Dict, List, Optional
 
-from fastapi import Request
+from fastapi import HTTPException, Request, UploadFile
 from fastapi.responses import JSONResponse, StreamingResponse
+from pydantic import ValidationError
 
 from sglang.srt.conversation import (
     Conversation,
@@ -17,6 +36,8 @@ from sglang.srt.conversation import (
 )
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.openai_api.protocol import (
+    BatchRequest,
+    BatchResponse,
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseChoice,
@@ -30,6 +51,8 @@ from sglang.srt.openai_api.protocol import (
     CompletionStreamResponse,
     DeltaMessage,
     ErrorResponse,
+    FileRequest,
+    FileResponse,
     LogProbs,
     UsageInfo,
 )
@@ -37,6 +60,24 @@ from sglang.srt.openai_api.protocol import (
 chat_template_name = None
 
 
+class FileMetadata:
+    def __init__(self, filename: str, purpose: str):
+        self.filename = filename
+        self.purpose = purpose
+
+
+# In-memory storage for batch jobs and files
+batch_storage: Dict[str, BatchResponse] = {}
+file_id_request: Dict[str, FileMetadata] = {}
+file_id_response: Dict[str, FileResponse] = {}
+## map file id to file path in SGlang backend
+file_id_storage: Dict[str, str] = {}
+
+
+# backend storage directory
+storage_dir = None
+
+
 def create_error_response(
     message: str,
     err_type: str = "BadRequestError",
@@ -91,33 +132,364 @@ def load_chat_template_for_openai_api(chat_template_arg):
     chat_template_name = chat_template_arg
 
 
-async def
-
-
-
-
-
+async def v1_files_create(file: UploadFile, purpose: str, file_storage_pth: str = None):
+    try:
+        global storage_dir
+        if file_storage_pth:
+            storage_dir = file_storage_pth
+        # Read the file content
+        file_content = await file.read()
+
+        # Create an instance of RequestBody
+        request_body = FileRequest(file=file_content, purpose=purpose)
+
+        # Save the file to the sglang_oai_storage directory
+        os.makedirs(storage_dir, exist_ok=True)
+        file_id = f"backend_input_file-{uuid.uuid4()}"
+        filename = f"{file_id}.jsonl"
+        file_path = os.path.join(storage_dir, filename)
+
+        with open(file_path, "wb") as f:
+            f.write(request_body.file)
+
+        # add info to global file map
+        file_id_request[file_id] = FileMetadata(filename=file.filename, purpose=purpose)
+        file_id_storage[file_id] = file_path
+
+        # Return the response in the required format
+        response = FileResponse(
+            id=file_id,
+            bytes=len(request_body.file),
+            created_at=int(time.time()),
+            filename=file.filename,
+            purpose=request_body.purpose,
+        )
+        file_id_response[file_id] = response
+
+        return response
+    except ValidationError as e:
+        return {"error": "Invalid input", "details": e.errors()}
+
+
+async def v1_batches(tokenizer_manager, raw_request: Request):
+    try:
+        body = await raw_request.json()
+
+        batch_request = BatchRequest(**body)
+
+        batch_id = f"batch_{uuid.uuid4()}"
+
+        # Create an instance of BatchResponse
+        batch_response = BatchResponse(
+            id=batch_id,
+            endpoint=batch_request.endpoint,
+            input_file_id=batch_request.input_file_id,
+            completion_window=batch_request.completion_window,
+            created_at=int(time.time()),
+            metadata=batch_request.metadata,
+        )
+
+        batch_storage[batch_id] = batch_response
+
+        # Start processing the batch asynchronously
+        asyncio.create_task(process_batch(tokenizer_manager, batch_id, batch_request))
+
+        # Return the initial batch_response
+        return batch_response
+
+    except ValidationError as e:
+        return {"error": "Invalid input", "details": e.errors()}
+    except Exception as e:
+        return {"error": str(e)}
+
+
+async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest):
+    try:
+        # Update the batch status to "in_progress"
+        batch_storage[batch_id].status = "in_progress"
+        batch_storage[batch_id].in_progress_at = int(time.time())
+
+        # Retrieve the input file content
+        input_file_request = file_id_request.get(batch_request.input_file_id)
+        if not input_file_request:
+            raise ValueError("Input file not found")
+
+        # Parse the JSONL file and process each request
+        input_file_path = file_id_storage.get(batch_request.input_file_id)
+        with open(input_file_path, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+
+        total_requests = len(lines)
+        completed_requests = 0
+        failed_requests = 0
+
+        all_ret = []
+        end_point = batch_storage[batch_id].endpoint
+        file_request_list = []
+        all_requests = []
+        for line in lines:
+            request_data = json.loads(line)
+            file_request_list.append(request_data)
+            body = request_data["body"]
+            if end_point == "/v1/chat/completions":
+                all_requests.append(ChatCompletionRequest(**body))
+            elif end_point == "/v1/completions":
+                all_requests.append(CompletionRequest(**body))
+        if end_point == "/v1/chat/completions":
+            adapted_request, request = v1_chat_generate_request(
+                all_requests, tokenizer_manager
+            )
+        elif end_point == "/v1/completions":
+            adapted_request, request = v1_generate_request(all_requests)
+        try:
+            ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
+            if not isinstance(ret, list):
+                ret = [ret]
+            if end_point == "/v1/chat/completions":
+                responses = v1_chat_generate_response(request, ret, to_file=True)
+            else:
+                responses = v1_generate_response(request, ret, to_file=True)
+
+        except Exception as e:
+            error_json = {
+                "id": f"batch_req_{uuid.uuid4()}",
+                "custom_id": request_data.get("custom_id"),
+                "response": None,
+                "error": {"message": str(e)},
+            }
+            all_ret.append(error_json)
+            failed_requests += len(file_request_list)
+
+        for idx, response in enumerate(responses):
+            ## the batch_req here can be changed to be named within a batch granularity
+            response_json = {
+                "id": f"batch_req_{uuid.uuid4()}",
+                "custom_id": file_request_list[idx].get("custom_id"),
+                "response": response,
+                "error": None,
+            }
+            all_ret.append(response_json)
+            completed_requests += 1
+        # Write results to a new file
+        output_file_id = f"backend_result_file-{uuid.uuid4()}"
+        global storage_dir
+        output_file_path = os.path.join(storage_dir, f"{output_file_id}.jsonl")
+        with open(output_file_path, "w", encoding="utf-8") as f:
+            for ret in all_ret:
+                f.write(json.dumps(ret) + "\n")
+
+        # Update batch response with output file information
+        retrieve_batch = batch_storage[batch_id]
+        retrieve_batch.output_file_id = output_file_id
+        file_id_storage[output_file_id] = output_file_path
+        # Update batch status to "completed"
+        retrieve_batch.status = "completed"
+        retrieve_batch.completed_at = int(time.time())
+        retrieve_batch.request_counts = {
+            "total": total_requests,
+            "completed": completed_requests,
+            "failed": failed_requests,
+        }
+
+    except Exception as e:
+        print("error in SGlang:", e)
+        # Update batch status to "failed"
+        retrieve_batch = batch_storage[batch_id]
+        retrieve_batch.status = "failed"
+        retrieve_batch.failed_at = int(time.time())
+        retrieve_batch.errors = {"message": str(e)}
+
+
+async def v1_retrieve_batch(batch_id: str):
+    # Retrieve the batch job from the in-memory storage
+    batch_response = batch_storage.get(batch_id)
+    if batch_response is None:
+        raise HTTPException(status_code=404, detail="Batch not found")
+
+    return batch_response
+
+
+async def v1_retrieve_file(file_id: str):
+    # Retrieve the batch job from the in-memory storage
+    file_response = file_id_response.get(file_id)
+    if file_response is None:
+        raise HTTPException(status_code=404, detail="File not found")
+    return file_response
+
+
+async def v1_retrieve_file_content(file_id: str):
+    file_pth = file_id_storage.get(file_id)
+    if not file_pth or not os.path.exists(file_pth):
+        raise HTTPException(status_code=404, detail="File not found")
+
+    def iter_file():
+        with open(file_pth, mode="rb") as file_like:
+            yield from file_like
+
+    return StreamingResponse(iter_file(), media_type="application/octet-stream")
+
+
+def v1_generate_request(all_requests):
+
+    prompts = []
+    sampling_params_list = []
+    first_prompt_type = type(all_requests[0].prompt)
+    for request in all_requests:
+        prompt = request.prompt
+        assert (
+            type(prompt) == first_prompt_type
+        ), "All prompts must be of the same type in file input settings"
+        prompts.append(prompt)
+        sampling_params_list.append(
+            {
+                "temperature": request.temperature,
+                "max_new_tokens": request.max_tokens,
+                "stop": request.stop,
+                "top_p": request.top_p,
+                "presence_penalty": request.presence_penalty,
+                "frequency_penalty": request.frequency_penalty,
+                "regex": request.regex,
+                "n": request.n,
+                "ignore_eos": request.ignore_eos,
+            }
+        )
+        if len(all_requests) > 1 and request.n > 1:
+            raise ValueError(
+                "Batch operation is not supported for completions from files"
+            )
+
+    if len(all_requests) == 1:
+        prompt = prompts[0]
+        sampling_params_list = sampling_params_list[0]
+        if isinstance(prompts, str) or isinstance(prompts[0], str):
+            prompt_kwargs = {"text": prompt}
+        else:
+            prompt_kwargs = {"input_ids": prompt}
     else:
-
+        if isinstance(prompts[0], str):
+            prompt_kwargs = {"text": prompts}
+        else:
+            prompt_kwargs = {"input_ids": prompts}
 
     adapted_request = GenerateReqInput(
         **prompt_kwargs,
-        sampling_params=
-
-
-
-
-
-            "frequency_penalty": request.frequency_penalty,
-            "regex": request.regex,
-            "n": request.n,
-            "ignore_eos": request.ignore_eos,
-        },
-        return_logprob=request.logprobs is not None and request.logprobs > 0,
-        top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
+        sampling_params=sampling_params_list,
+        return_logprob=all_requests[0].logprobs is not None
+        and all_requests[0].logprobs > 0,
+        top_logprobs_num=(
+            all_requests[0].logprobs if all_requests[0].logprobs is not None else 0
+        ),
         return_text_in_logprobs=True,
-        stream=
+        stream=all_requests[0].stream,
     )
+    if len(all_requests) == 1:
+        return adapted_request, all_requests[0]
+    return adapted_request, all_requests
+
+
+def v1_generate_response(request, ret, to_file=False):
+    choices = []
+    echo = False
+
+    if (not isinstance(request, List)) and request.echo:
+        # TODO: handle the case propmt is token ids
+        if isinstance(request.prompt, list):
+            prompts = request.prompt
+        else:
+            prompts = [request.prompt]
+        echo = True
+
+    for idx, ret_item in enumerate(ret):
+        text = ret_item["text"]
+        if isinstance(request, List) and request[idx].echo:
+            echo = True
+            text = request[idx].prompt + text
+        if (not isinstance(request, List)) and echo:
+            text = prompts[idx] + text
+
+        logprobs = False
+        if isinstance(request, List) and request[idx].logprobs:
+            logprobs = True
+        elif (not isinstance(request, List)) and request.logprobs:
+            logprobs = True
+        if logprobs:
+            if echo:
+                input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
+                input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
+            else:
+                input_token_logprobs = None
+                input_top_logprobs = None
+
+            logprobs = to_openai_style_logprobs(
+                input_token_logprobs=input_token_logprobs,
+                input_top_logprobs=input_top_logprobs,
+                output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
+                output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
+            )
+        else:
+            logprobs = None
+
+        if to_file:
+            ## to make the choise data json serializable
+            choice_data = {
+                "index": 0,
+                "text": text,
+                "logprobs": logprobs,
+                "finish_reason": ret_item["meta_info"]["finish_reason"],
+            }
+        else:
+            choice_data = CompletionResponseChoice(
+                index=idx,
+                text=text,
+                logprobs=logprobs,
+                finish_reason=ret_item["meta_info"]["finish_reason"],
+            )
+
+        choices.append(choice_data)
+
+    if to_file:
+        responses = []
+        for i, choice in enumerate(choices):
+            response = {
+                "status_code": 200,
+                "request_id": ret[i]["meta_info"]["id"],
+                "body": {
+                    ## remain the same but if needed we can change that
+                    "id": ret[i]["meta_info"]["id"],
+                    "object": "text_completion",
+                    "created": int(time.time()),
+                    "model": request[i].model,
+                    "choices": choice,
+                    "usage": {
+                        "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"],
+                        "completion_tokens": ret[i]["meta_info"]["completion_tokens"],
+                        "total_tokens": ret[i]["meta_info"]["prompt_tokens"]
+                        + ret[i]["meta_info"]["completion_tokens"],
+                    },
+                    "system_fingerprint": None,
+                },
+            }
+            responses.append(response)
+        return responses
+    else:
+        completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
+        response = CompletionResponse(
+            id=ret[0]["meta_info"]["id"],
+            model=request.model,
+            choices=choices,
+            usage=UsageInfo(
+                prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
+                completion_tokens=completion_tokens,
+                total_tokens=ret[0]["meta_info"]["prompt_tokens"] + completion_tokens,
+            ),
+        )
+        return response
+
+
+async def v1_completions(tokenizer_manager, raw_request: Request):
+    request_json = await raw_request.json()
+    all_requests = [CompletionRequest(**request_json)]
+    adapted_request, request = v1_generate_request(all_requests)
 
     if adapted_request.stream:
 
@@ -208,105 +580,144 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
 
     if not isinstance(ret, list):
         ret = [ret]
-    choices = []
 
-
-
+    response = v1_generate_response(request, ret)
+    return response
 
-    if request.echo:
-        text = request.prompt + text
 
-
-
-
-
+def v1_chat_generate_request(all_requests, tokenizer_manager):
+
+    texts = []
+    sampling_params_list = []
+    image_data_list = []
+    for request in all_requests:
+        # Prep the data needed for the underlying GenerateReqInput:
+        # - prompt: The full prompt string.
+        # - stop: Custom stop tokens.
+        # - image_data: None or a list of image strings (URLs or base64 strings).
+        # None skips any image processing in GenerateReqInput.
+        if not isinstance(request.messages, str):
+            # Apply chat template and its stop strings.
+            if chat_template_name is None:
+                prompt = tokenizer_manager.tokenizer.apply_chat_template(
+                    request.messages, tokenize=False, add_generation_prompt=True
+                )
+                stop = request.stop
+                image_data = None
             else:
-
-
-
-
-
-
-
-
-
+                conv = generate_chat_conv(request, chat_template_name)
+                prompt = conv.get_prompt()
+                image_data = conv.image_data
+                stop = conv.stop_str or []
+                if request.stop:
+                    if isinstance(request.stop, str):
+                        stop.append(request.stop)
+                    else:
+                        stop.extend(request.stop)
         else:
-
-
-
-
-
-
-
+            # Use the raw prompt and stop strings if the messages is already a string.
+            prompt = request.messages
+            stop = request.stop
+            image_data = None
+        texts.append(prompt)
+        sampling_params_list.append(
+            {
+                "temperature": request.temperature,
+                "max_new_tokens": request.max_tokens,
+                "stop": stop,
+                "top_p": request.top_p,
+                "presence_penalty": request.presence_penalty,
+                "frequency_penalty": request.frequency_penalty,
+                "regex": request.regex,
+                "n": request.n,
+            }
         )
-
-
-
-
-
-
-
-
-
-
-                item["meta_info"]["completion_tokens"] for item in ret
-            ),
-            total_tokens=ret[0]["meta_info"]["prompt_tokens"]
-            + sum(item["meta_info"]["completion_tokens"] for item in ret),
-        ),
+        image_data_list.append(image_data)
+    if len(all_requests) == 1:
+        texts = texts[0]
+        sampling_params_list = sampling_params_list[0]
+        image_data = image_data_list[0]
+    adapted_request = GenerateReqInput(
+        text=texts,
+        image_data=image_data,
+        sampling_params=sampling_params_list,
+        stream=request.stream,
     )
+    if len(all_requests) == 1:
+        return adapted_request, all_requests[0]
+    return adapted_request, all_requests
 
-    return response
 
+def v1_chat_generate_response(request, ret, to_file=False):
+    choices = []
+    total_prompt_tokens = 0
+    total_completion_tokens = 0
 
-
-
-
-
-
-
-
-
-
-
-
-
-        prompt = tokenizer_manager.tokenizer.apply_chat_template(
-            request.messages, tokenize=False, add_generation_prompt=True
-        )
-        stop = request.stop
-        image_data = None
+    for idx, ret_item in enumerate(ret):
+        prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
+        completion_tokens = ret_item["meta_info"]["completion_tokens"]
+
+        if to_file:
+            ## to make the choice data json serializable
+            choice_data = {
+                "index": 0,
+                "message": {"role": "assistant", "content": ret_item["text"]},
+                "logprobs": None,
+                "finish_reason": ret_item["meta_info"]["finish_reason"],
+            }
         else:
-
-
-
-
-
-
-
-
-
+            choice_data = ChatCompletionResponseChoice(
+                index=idx,
+                message=ChatMessage(role="assistant", content=ret_item["text"]),
+                finish_reason=ret_item["meta_info"]["finish_reason"],
+            )
+
+        choices.append(choice_data)
+        total_prompt_tokens = prompt_tokens
+        total_completion_tokens += completion_tokens
+    if to_file:
+        responses = []
+
+        for i, choice in enumerate(choices):
+            response = {
+                "status_code": 200,
+                "request_id": ret[i]["meta_info"]["id"],
+                "body": {
+                    ## remain the same but if needed we can change that
+                    "id": ret[i]["meta_info"]["id"],
+                    "object": "chat.completion",
+                    "created": int(time.time()),
+                    "model": request[i].model,
+                    "choices": choice,
+                    "usage": {
+                        "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"],
+                        "completion_tokens": ret[i]["meta_info"]["completion_tokens"],
+                        "total_tokens": ret[i]["meta_info"]["prompt_tokens"]
+                        + ret[i]["meta_info"]["completion_tokens"],
+                    },
+                    "system_fingerprint": None,
+                },
+            }
+            responses.append(response)
+        return responses
     else:
-
-
-
-
+        response = ChatCompletionResponse(
+            id=ret[0]["meta_info"]["id"],
+            model=request.model,
+            choices=choices,
+            usage=UsageInfo(
+                prompt_tokens=total_prompt_tokens,
+                completion_tokens=total_completion_tokens,
+                total_tokens=total_prompt_tokens + total_completion_tokens,
+            ),
+        )
+        return response
 
-
-
-
-
-
-            "max_new_tokens": request.max_tokens,
-            "stop": stop,
-            "top_p": request.top_p,
-            "presence_penalty": request.presence_penalty,
-            "frequency_penalty": request.frequency_penalty,
-            "regex": request.regex,
-            "n": request.n,
-        },
-        stream=request.stream,
-    )
+
+async def v1_chat_completions(tokenizer_manager, raw_request: Request):
+    request_json = await raw_request.json()
+    all_requests = [ChatCompletionRequest(**request_json)]
+    adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)
 
     if adapted_request.stream:
 
@@ -368,34 +779,8 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
 
     if not isinstance(ret, list):
         ret = [ret]
-    choices = []
-    total_prompt_tokens = 0
-    total_completion_tokens = 0
-
-    for idx, ret_item in enumerate(ret):
-        prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
-        completion_tokens = ret_item["meta_info"]["completion_tokens"]
-
-        choice_data = ChatCompletionResponseChoice(
-            index=idx,
-            message=ChatMessage(role="assistant", content=ret_item["text"]),
-            finish_reason=ret_item["meta_info"]["finish_reason"],
-        )
 
-
-        total_prompt_tokens = prompt_tokens
-        total_completion_tokens += completion_tokens
-
-    response = ChatCompletionResponse(
-        id=ret[0]["meta_info"]["id"],
-        model=request.model,
-        choices=choices,
-        usage=UsageInfo(
-            prompt_tokens=total_prompt_tokens,
-            completion_tokens=total_completion_tokens,
-            total_tokens=total_prompt_tokens + total_completion_tokens,
-        ),
-    )
+    response = v1_chat_generate_response(request, ret)
 
     return response
 
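For reference, process_batch in the diff above reads the uploaded file as JSONL, takes only the "custom_id" and "body" fields from each line, and builds ChatCompletionRequest or CompletionRequest objects according to the batch endpoint. The sketch below writes an input file matching that parsing; the model name and message contents are placeholders, and the "method"/"url" fields are kept only for parity with the OpenAI batch format (process_batch ignores them).

import json

# Two chat-completion requests for a batch created with
# endpoint="/v1/chat/completions"; only "custom_id" and "body" are read back.
batch_lines = [
    {
        "custom_id": "req-1",
        "method": "POST",               # ignored by process_batch
        "url": "/v1/chat/completions",  # ignored by process_batch
        "body": {
            "model": "default",         # placeholder model name
            "messages": [{"role": "user", "content": "Hello"}],
            "max_tokens": 32,
        },
    },
    {
        "custom_id": "req-2",
        "body": {
            "model": "default",
            "messages": [{"role": "user", "content": "Summarize SGLang in one line."}],
            "max_tokens": 32,
        },
    },
]

with open("batch_input.jsonl", "w", encoding="utf-8") as f:
    for line in batch_lines:
        f.write(json.dumps(line) + "\n")

Note the constraints visible in v1_generate_request for /v1/completions batches: every prompt in one file must be of the same type, and n > 1 is rejected when the file holds more than one request.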