sglang-0.2.6-py3-none-any.whl → sglang-0.2.8-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. The information is provided for informational purposes only.
Files changed (82)
  1. sglang/__init__.py +33 -26
  2. sglang/api.py +9 -1
  3. sglang/bench_latency.py +2 -2
  4. sglang/bench_serving.py +10 -1
  5. sglang/check_env.py +1 -1
  6. sglang/lang/backend/litellm.py +1 -1
  7. sglang/lang/backend/openai.py +1 -1
  8. sglang/lang/interpreter.py +21 -5
  9. sglang/lang/ir.py +1 -2
  10. sglang/srt/constrained/__init__.py +15 -0
  11. sglang/srt/constrained/{base_cache.py → base_tool_cache.py} +17 -2
  12. sglang/srt/constrained/fsm_cache.py +17 -2
  13. sglang/srt/constrained/jump_forward.py +17 -2
  14. sglang/srt/conversation.py +26 -0
  15. sglang/srt/hf_transformers_utils.py +15 -0
  16. sglang/srt/layers/context_flashattention_nopad.py +15 -0
  17. sglang/srt/layers/extend_attention.py +15 -0
  18. sglang/srt/layers/fused_moe.py +15 -0
  19. sglang/srt/layers/linear.py +15 -0
  20. sglang/srt/layers/logits_processor.py +41 -13
  21. sglang/srt/layers/quantization/__init__.py +15 -0
  22. sglang/srt/layers/quantization/fp8.py +15 -0
  23. sglang/srt/layers/radix_attention.py +17 -2
  24. sglang/srt/layers/token_attention.py +16 -1
  25. sglang/srt/managers/{controller/manager_multi.py → controller_multi.py} +17 -2
  26. sglang/srt/managers/{controller/manager_single.py → controller_single.py} +17 -2
  27. sglang/srt/managers/detokenizer_manager.py +16 -1
  28. sglang/srt/managers/io_struct.py +36 -3
  29. sglang/srt/managers/{controller/schedule_heuristic.py → policy_scheduler.py} +37 -22
  30. sglang/srt/managers/{controller/infer_batch.py → schedule_batch.py} +60 -21
  31. sglang/srt/managers/tokenizer_manager.py +39 -16
  32. sglang/srt/managers/{controller/tp_worker.py → tp_worker.py} +159 -46
  33. sglang/srt/mem_cache/base_cache.py +43 -0
  34. sglang/srt/mem_cache/chunk_cache.py +60 -0
  35. sglang/srt/mem_cache/flush_cache.py +33 -0
  36. sglang/srt/{memory_pool.py → mem_cache/memory_pool.py} +16 -1
  37. sglang/srt/{managers/controller → mem_cache}/radix_cache.py +20 -2
  38. sglang/srt/mm_utils.py +15 -0
  39. sglang/srt/model_config.py +15 -0
  40. sglang/srt/{managers/controller → model_executor}/cuda_graph_runner.py +16 -1
  41. sglang/srt/{managers/controller → model_executor}/model_runner.py +49 -14
  42. sglang/srt/model_loader/model_loader.py +15 -0
  43. sglang/srt/model_loader/utils.py +16 -1
  44. sglang/srt/models/chatglm.py +16 -1
  45. sglang/srt/models/commandr.py +16 -1
  46. sglang/srt/models/dbrx.py +16 -1
  47. sglang/srt/models/deepseek.py +16 -1
  48. sglang/srt/models/deepseek_v2.py +16 -1
  49. sglang/srt/models/gemma.py +16 -1
  50. sglang/srt/models/gemma2.py +16 -1
  51. sglang/srt/models/gpt_bigcode.py +16 -1
  52. sglang/srt/models/grok.py +16 -1
  53. sglang/srt/models/internlm2.py +16 -1
  54. sglang/srt/models/llama2.py +21 -22
  55. sglang/srt/models/llama_classification.py +16 -1
  56. sglang/srt/models/llava.py +17 -2
  57. sglang/srt/models/llavavid.py +17 -2
  58. sglang/srt/models/minicpm.py +16 -1
  59. sglang/srt/models/mistral.py +15 -0
  60. sglang/srt/models/mixtral.py +16 -1
  61. sglang/srt/models/mixtral_quant.py +16 -1
  62. sglang/srt/models/qwen.py +16 -1
  63. sglang/srt/models/qwen2.py +16 -1
  64. sglang/srt/models/qwen2_moe.py +16 -1
  65. sglang/srt/models/stablelm.py +16 -1
  66. sglang/srt/models/yivl.py +15 -0
  67. sglang/srt/openai_api/adapter.py +569 -131
  68. sglang/srt/openai_api/protocol.py +84 -2
  69. sglang/srt/sampling_params.py +15 -0
  70. sglang/srt/server.py +92 -23
  71. sglang/srt/server_args.py +52 -11
  72. sglang/srt/utils.py +15 -0
  73. sglang/test/test_programs.py +9 -6
  74. sglang/utils.py +22 -0
  75. sglang/version.py +1 -1
  76. {sglang-0.2.6.dist-info → sglang-0.2.8.dist-info}/METADATA +33 -7
  77. sglang-0.2.8.dist-info/RECORD +95 -0
  78. {sglang-0.2.6.dist-info → sglang-0.2.8.dist-info}/WHEEL +1 -1
  79. sglang/srt/flush_cache.py +0 -18
  80. sglang-0.2.6.dist-info/RECORD +0 -93
  81. {sglang-0.2.6.dist-info → sglang-0.2.8.dist-info}/LICENSE +0 -0
  82. {sglang-0.2.6.dist-info → sglang-0.2.8.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/adapter.py

@@ -1,12 +1,31 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """Conversion between OpenAI APIs and native SRT APIs"""
 
 import asyncio
 import json
 import os
+import time
+import uuid
 from http import HTTPStatus
+from typing import Dict, List, Optional
 
-from fastapi import Request
+from fastapi import HTTPException, Request, UploadFile
 from fastapi.responses import JSONResponse, StreamingResponse
+from pydantic import ValidationError
 
 from sglang.srt.conversation import (
     Conversation,
@@ -17,12 +36,16 @@ from sglang.srt.conversation import (
 )
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.openai_api.protocol import (
+    BatchRequest,
+    BatchResponse,
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseChoice,
     ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse,
+    ChatCompletionTokenLogprob,
     ChatMessage,
+    ChoiceLogprobs,
     CompletionRequest,
     CompletionResponse,
     CompletionResponseChoice,
@@ -30,13 +53,34 @@ from sglang.srt.openai_api.protocol import (
     CompletionStreamResponse,
     DeltaMessage,
     ErrorResponse,
+    FileRequest,
+    FileResponse,
     LogProbs,
+    TopLogprob,
     UsageInfo,
 )
 
 chat_template_name = None
 
 
+class FileMetadata:
+    def __init__(self, filename: str, purpose: str):
+        self.filename = filename
+        self.purpose = purpose
+
+
+# In-memory storage for batch jobs and files
+batch_storage: Dict[str, BatchResponse] = {}
+file_id_request: Dict[str, FileMetadata] = {}
+file_id_response: Dict[str, FileResponse] = {}
+# map file id to file path in SGlang backend
+file_id_storage: Dict[str, str] = {}
+
+
+# backend storage directory
+storage_dir = None
+
+
 def create_error_response(
     message: str,
     err_type: str = "BadRequestError",
@@ -91,33 +135,368 @@ def load_chat_template_for_openai_api(chat_template_arg):
         chat_template_name = chat_template_arg
 
 
-async def v1_completions(tokenizer_manager, raw_request: Request):
-    request_json = await raw_request.json()
-    request = CompletionRequest(**request_json)
-    prompt = request.prompt
-    if isinstance(prompt, str) or isinstance(prompt[0], str):
-        prompt_kwargs = {"text": prompt}
-    else:
-        prompt_kwargs = {"input_ids": prompt}
+async def v1_files_create(file: UploadFile, purpose: str, file_storage_pth: str = None):
+    try:
+        global storage_dir
+        if file_storage_pth:
+            storage_dir = file_storage_pth
+        # Read the file content
+        file_content = await file.read()
+
+        # Create an instance of RequestBody
+        request_body = FileRequest(file=file_content, purpose=purpose)
+
+        # Save the file to the sglang_oai_storage directory
+        os.makedirs(storage_dir, exist_ok=True)
+        file_id = f"backend_input_file-{uuid.uuid4()}"
+        filename = f"{file_id}.jsonl"
+        file_path = os.path.join(storage_dir, filename)
+
+        with open(file_path, "wb") as f:
+            f.write(request_body.file)
+
+        # add info to global file map
+        file_id_request[file_id] = FileMetadata(filename=file.filename, purpose=purpose)
+        file_id_storage[file_id] = file_path
+
+        # Return the response in the required format
+        response = FileResponse(
+            id=file_id,
+            bytes=len(request_body.file),
+            created_at=int(time.time()),
+            filename=file.filename,
+            purpose=request_body.purpose,
+        )
+        file_id_response[file_id] = response
+
+        return response
+    except ValidationError as e:
+        return {"error": "Invalid input", "details": e.errors()}
+
+
+async def v1_batches(tokenizer_manager, raw_request: Request):
+    try:
+        body = await raw_request.json()
 
+        batch_request = BatchRequest(**body)
+
+        batch_id = f"batch_{uuid.uuid4()}"
+
+        # Create an instance of BatchResponse
+        batch_response = BatchResponse(
+            id=batch_id,
+            endpoint=batch_request.endpoint,
+            input_file_id=batch_request.input_file_id,
+            completion_window=batch_request.completion_window,
+            created_at=int(time.time()),
+            metadata=batch_request.metadata,
+        )
+
+        batch_storage[batch_id] = batch_response
+
+        # Start processing the batch asynchronously
+        asyncio.create_task(process_batch(tokenizer_manager, batch_id, batch_request))
+
+        # Return the initial batch_response
+        return batch_response
+
+    except ValidationError as e:
+        return {"error": "Invalid input", "details": e.errors()}
+    except Exception as e:
+        return {"error": str(e)}
+
+
+async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest):
+    try:
+        # Update the batch status to "in_progress"
+        batch_storage[batch_id].status = "in_progress"
+        batch_storage[batch_id].in_progress_at = int(time.time())
+
+        # Retrieve the input file content
+        input_file_request = file_id_request.get(batch_request.input_file_id)
+        if not input_file_request:
+            raise ValueError("Input file not found")
+
+        # Parse the JSONL file and process each request
+        input_file_path = file_id_storage.get(batch_request.input_file_id)
+        with open(input_file_path, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+
+        total_requests = len(lines)
+        completed_requests = 0
+        failed_requests = 0
+
+        all_ret = []
+        end_point = batch_storage[batch_id].endpoint
+        file_request_list = []
+        all_requests = []
+        for line in lines:
+            request_data = json.loads(line)
+            file_request_list.append(request_data)
+            body = request_data["body"]
+            if end_point == "/v1/chat/completions":
+                all_requests.append(ChatCompletionRequest(**body))
+            elif end_point == "/v1/completions":
+                all_requests.append(CompletionRequest(**body))
+        if end_point == "/v1/chat/completions":
+            adapted_request, request = v1_chat_generate_request(
+                all_requests, tokenizer_manager
+            )
+        elif end_point == "/v1/completions":
+            adapted_request, request = v1_generate_request(all_requests)
+        try:
+            ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
+            if not isinstance(ret, list):
+                ret = [ret]
+            if end_point == "/v1/chat/completions":
+                responses = v1_chat_generate_response(request, ret, to_file=True)
+            else:
+                responses = v1_generate_response(request, ret, to_file=True)
+
+        except Exception as e:
+            error_json = {
+                "id": f"batch_req_{uuid.uuid4()}",
+                "custom_id": request_data.get("custom_id"),
+                "response": None,
+                "error": {"message": str(e)},
+            }
+            all_ret.append(error_json)
+            failed_requests += len(file_request_list)
+
+        for idx, response in enumerate(responses):
+            # the batch_req here can be changed to be named within a batch granularity
+            response_json = {
+                "id": f"batch_req_{uuid.uuid4()}",
+                "custom_id": file_request_list[idx].get("custom_id"),
+                "response": response,
+                "error": None,
+            }
+            all_ret.append(response_json)
+            completed_requests += 1
+        # Write results to a new file
+        output_file_id = f"backend_result_file-{uuid.uuid4()}"
+        global storage_dir
+        output_file_path = os.path.join(storage_dir, f"{output_file_id}.jsonl")
+        with open(output_file_path, "w", encoding="utf-8") as f:
+            for ret in all_ret:
+                f.write(json.dumps(ret) + "\n")
+
+        # Update batch response with output file information
+        retrieve_batch = batch_storage[batch_id]
+        retrieve_batch.output_file_id = output_file_id
+        file_id_storage[output_file_id] = output_file_path
+        # Update batch status to "completed"
+        retrieve_batch.status = "completed"
+        retrieve_batch.completed_at = int(time.time())
+        retrieve_batch.request_counts = {
+            "total": total_requests,
+            "completed": completed_requests,
+            "failed": failed_requests,
+        }
+
+    except Exception as e:
+        print("error in SGlang:", e)
+        # Update batch status to "failed"
+        retrieve_batch = batch_storage[batch_id]
+        retrieve_batch.status = "failed"
+        retrieve_batch.failed_at = int(time.time())
+        retrieve_batch.errors = {"message": str(e)}
+
+
+async def v1_retrieve_batch(batch_id: str):
+    # Retrieve the batch job from the in-memory storage
+    batch_response = batch_storage.get(batch_id)
+    if batch_response is None:
+        raise HTTPException(status_code=404, detail="Batch not found")
+
+    return batch_response
+
+
+async def v1_retrieve_file(file_id: str):
+    # Retrieve the batch job from the in-memory storage
+    file_response = file_id_response.get(file_id)
+    if file_response is None:
+        raise HTTPException(status_code=404, detail="File not found")
+    return file_response
+
+
+async def v1_retrieve_file_content(file_id: str):
+    file_pth = file_id_storage.get(file_id)
+    if not file_pth or not os.path.exists(file_pth):
+        raise HTTPException(status_code=404, detail="File not found")
+
+    def iter_file():
+        with open(file_pth, mode="rb") as file_like:
+            yield from file_like
+
+    return StreamingResponse(iter_file(), media_type="application/octet-stream")
+
+
+def v1_generate_request(all_requests):
+
+    prompts = []
+    sampling_params_list = []
+    return_logprobs = []
+    top_logprobs_nums = []
+    first_prompt_type = type(all_requests[0].prompt)
+    for request in all_requests:
+        prompt = request.prompt
+        assert (
+            type(prompt) == first_prompt_type
+        ), "All prompts must be of the same type in file input settings"
+        prompts.append(prompt)
+        return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
+        top_logprobs_nums.append(
+            request.logprobs if request.logprobs is not None else 0
+        )
+        sampling_params_list.append(
+            {
+                "temperature": request.temperature,
+                "max_new_tokens": request.max_tokens,
+                "stop": request.stop,
+                "top_p": request.top_p,
+                "presence_penalty": request.presence_penalty,
+                "frequency_penalty": request.frequency_penalty,
+                "regex": request.regex,
+                "n": request.n,
+                "ignore_eos": request.ignore_eos,
+            }
+        )
+        if len(all_requests) > 1 and request.n > 1:
+            raise ValueError(
+                "Batch operation is not supported for completions from files"
+            )
+
+    if len(all_requests) == 1:
+        prompt = prompts[0]
+        sampling_params_list = sampling_params_list[0]
+        return_logprobs = return_logprobs[0]
+        top_logprobs_nums = top_logprobs_nums[0]
+        if isinstance(prompt, str) or isinstance(prompt[0], str):
+            prompt_kwargs = {"text": prompt}
+        else:
+            prompt_kwargs = {"input_ids": prompt}
+    else:
+        if isinstance(prompts[0], str):
+            prompt_kwargs = {"text": prompts}
+        else:
+            prompt_kwargs = {"input_ids": prompts}
     adapted_request = GenerateReqInput(
         **prompt_kwargs,
-        sampling_params={
-            "temperature": request.temperature,
-            "max_new_tokens": request.max_tokens,
-            "stop": request.stop,
-            "top_p": request.top_p,
-            "presence_penalty": request.presence_penalty,
-            "frequency_penalty": request.frequency_penalty,
-            "regex": request.regex,
-            "n": request.n,
-            "ignore_eos": request.ignore_eos,
-        },
-        return_logprob=request.logprobs is not None and request.logprobs > 0,
-        top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
+        sampling_params=sampling_params_list,
+        return_logprob=return_logprobs,
+        top_logprobs_num=top_logprobs_nums,
         return_text_in_logprobs=True,
-        stream=request.stream,
+        stream=all_requests[0].stream,
     )
+    if len(all_requests) == 1:
+        return adapted_request, all_requests[0]
+    return adapted_request, all_requests
+
+
+def v1_generate_response(request, ret, to_file=False):
+    choices = []
+    echo = False
+
+    if (not isinstance(request, List)) and request.echo:
+        # TODO: handle the case propmt is token ids
+        if isinstance(request.prompt, list):
+            prompts = request.prompt
+        else:
+            prompts = [request.prompt]
+        echo = True
+
+    for idx, ret_item in enumerate(ret):
+        text = ret_item["text"]
+        if isinstance(request, List) and request[idx].echo:
+            echo = True
+            text = request[idx].prompt + text
+        if (not isinstance(request, List)) and echo:
+            text = prompts[idx] + text
+
+        logprobs = False
+        if isinstance(request, List) and request[idx].logprobs:
+            logprobs = True
+        elif (not isinstance(request, List)) and request.logprobs:
+            logprobs = True
+        if logprobs:
+            if echo:
+                input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
+                input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
+            else:
+                input_token_logprobs = None
+                input_top_logprobs = None
+
+            logprobs = to_openai_style_logprobs(
+                input_token_logprobs=input_token_logprobs,
+                input_top_logprobs=input_top_logprobs,
+                output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
+                output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
+            )
+        else:
+            logprobs = None
+
+        if to_file:
+            # to make the choise data json serializable
+            choice_data = {
+                "index": 0,
+                "text": text,
+                "logprobs": logprobs,
+                "finish_reason": ret_item["meta_info"]["finish_reason"],
+            }
+        else:
+            choice_data = CompletionResponseChoice(
+                index=idx,
+                text=text,
+                logprobs=logprobs,
+                finish_reason=ret_item["meta_info"]["finish_reason"],
+            )
+
+        choices.append(choice_data)
+
+    if to_file:
+        responses = []
+        for i, choice in enumerate(choices):
+            response = {
+                "status_code": 200,
+                "request_id": ret[i]["meta_info"]["id"],
+                "body": {
+                    # remain the same but if needed we can change that
+                    "id": ret[i]["meta_info"]["id"],
+                    "object": "text_completion",
+                    "created": int(time.time()),
+                    "model": request[i].model,
+                    "choices": choice,
+                    "usage": {
+                        "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"],
+                        "completion_tokens": ret[i]["meta_info"]["completion_tokens"],
+                        "total_tokens": ret[i]["meta_info"]["prompt_tokens"]
+                        + ret[i]["meta_info"]["completion_tokens"],
+                    },
+                    "system_fingerprint": None,
+                },
+            }
+            responses.append(response)
+        return responses
+    else:
+        completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
+        response = CompletionResponse(
+            id=ret[0]["meta_info"]["id"],
+            model=request.model,
+            choices=choices,
+            usage=UsageInfo(
+                prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
+                completion_tokens=completion_tokens,
+                total_tokens=ret[0]["meta_info"]["prompt_tokens"] + completion_tokens,
+            ),
+        )
+        return response
+
+
+async def v1_completions(tokenizer_manager, raw_request: Request):
+    request_json = await raw_request.json()
+    all_requests = [CompletionRequest(**request_json)]
+    adapted_request, request = v1_generate_request(all_requests)
 
     if adapted_request.stream:
 
@@ -208,105 +587,190 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
 
     if not isinstance(ret, list):
         ret = [ret]
-    choices = []
 
-    for idx, ret_item in enumerate(ret):
-        text = ret_item["text"]
+    response = v1_generate_response(request, ret)
+    return response
 
-        if request.echo:
-            text = request.prompt + text
 
-        if request.logprobs:
-            if request.echo:
-                input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
-                input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
+def v1_chat_generate_request(all_requests, tokenizer_manager):
+
+    texts = []
+    sampling_params_list = []
+    image_data_list = []
+    return_logprobs = []
+    top_logprobs_nums = []
+    for request in all_requests:
+        # Prep the data needed for the underlying GenerateReqInput:
+        #  - prompt: The full prompt string.
+        #  - stop: Custom stop tokens.
+        #  - image_data: None or a list of image strings (URLs or base64 strings).
+        #    None skips any image processing in GenerateReqInput.
+        if not isinstance(request.messages, str):
+            # Apply chat template and its stop strings.
+            if chat_template_name is None:
+                prompt = tokenizer_manager.tokenizer.apply_chat_template(
+                    request.messages, tokenize=False, add_generation_prompt=True
+                )
+                stop = request.stop
+                image_data = None
             else:
-                input_token_logprobs = None
-                input_top_logprobs = None
+                conv = generate_chat_conv(request, chat_template_name)
+                prompt = conv.get_prompt()
+                image_data = conv.image_data
+                stop = conv.stop_str or []
+                if request.stop:
+                    if isinstance(request.stop, str):
+                        stop.append(request.stop)
+                    else:
+                        stop.extend(request.stop)
+        else:
+            # Use the raw prompt and stop strings if the messages is already a string.
+            prompt = request.messages
+            stop = request.stop
+            image_data = None
+        texts.append(prompt)
+        return_logprobs.append(request.logprobs)
+        top_logprobs_nums.append(request.top_logprobs)
+        sampling_params_list.append(
+            {
+                "temperature": request.temperature,
+                "max_new_tokens": request.max_tokens,
+                "stop": stop,
+                "top_p": request.top_p,
+                "presence_penalty": request.presence_penalty,
+                "frequency_penalty": request.frequency_penalty,
+                "regex": request.regex,
+                "n": request.n,
+            }
+        )
+        image_data_list.append(image_data)
+    if len(all_requests) == 1:
+        texts = texts[0]
+        sampling_params_list = sampling_params_list[0]
+        image_data = image_data_list[0]
+        return_logprobs = return_logprobs[0]
+        top_logprobs_nums = top_logprobs_nums[0]
+    adapted_request = GenerateReqInput(
+        text=texts,
+        image_data=image_data,
+        sampling_params=sampling_params_list,
+        return_logprob=return_logprobs,
+        top_logprobs_num=top_logprobs_nums,
+        stream=all_requests[0].stream,
+        return_text_in_logprobs=True,
+    )
+    if len(all_requests) == 1:
+        return adapted_request, all_requests[0]
+    return adapted_request, all_requests
+
+
+def v1_chat_generate_response(request, ret, to_file=False):
+    choices = []
+    total_prompt_tokens = 0
+    total_completion_tokens = 0
 
+    for idx, ret_item in enumerate(ret):
+        logprobs = False
+        if isinstance(request, List) and request[idx].logprobs:
+            logprobs = True
+        elif (not isinstance(request, List)) and request.logprobs:
+            logprobs = True
+        if logprobs:
             logprobs = to_openai_style_logprobs(
-                input_token_logprobs=input_token_logprobs,
-                input_top_logprobs=input_top_logprobs,
                 output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
                 output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
             )
+            token_logprobs = []
+            for token, logprob in zip(logprobs.tokens, logprobs.token_logprobs):
+                token_bytes = list(token.encode("utf-8"))
+                top_logprobs = []
+                if logprobs.top_logprobs:
+                    for top_token, top_logprob in logprobs.top_logprobs[0].items():
+                        top_token_bytes = list(top_token.encode("utf-8"))
+                        top_logprobs.append(
+                            TopLogprob(
+                                token=top_token,
+                                bytes=top_token_bytes,
+                                logprob=top_logprob,
+                            )
+                        )
+                token_logprobs.append(
+                    ChatCompletionTokenLogprob(
+                        token=token,
+                        bytes=token_bytes,
+                        logprob=logprob,
+                        top_logprobs=top_logprobs,
+                    )
+                )
+
+            choice_logprobs = ChoiceLogprobs(content=token_logprobs)
         else:
-            logprobs = None
+            choice_logprobs = None
+        prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
+        completion_tokens = ret_item["meta_info"]["completion_tokens"]
 
-        choice_data = CompletionResponseChoice(
-            index=idx,
-            text=text,
-            logprobs=logprobs,
-            finish_reason=ret_item["meta_info"]["finish_reason"],
-        )
+        if to_file:
+            # to make the choice data json serializable
+            choice_data = {
+                "index": 0,
+                "message": {"role": "assistant", "content": ret_item["text"]},
+                "logprobs": choice_logprobs,
+                "finish_reason": ret_item["meta_info"]["finish_reason"],
+            }
+        else:
+            choice_data = ChatCompletionResponseChoice(
+                index=idx,
+                message=ChatMessage(role="assistant", content=ret_item["text"]),
+                logprobs=choice_logprobs,
+                finish_reason=ret_item["meta_info"]["finish_reason"],
+            )
 
         choices.append(choice_data)
-
-    response = CompletionResponse(
-        id=ret[0]["meta_info"]["id"],
-        model=request.model,
-        choices=choices,
-        usage=UsageInfo(
-            prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
-            completion_tokens=sum(
-                item["meta_info"]["completion_tokens"] for item in ret
+        total_prompt_tokens += prompt_tokens
+        total_completion_tokens += completion_tokens
+    if to_file:
+        responses = []
+
+        for i, choice in enumerate(choices):
+            response = {
+                "status_code": 200,
+                "request_id": ret[i]["meta_info"]["id"],
+                "body": {
+                    # remain the same but if needed we can change that
+                    "id": ret[i]["meta_info"]["id"],
+                    "object": "chat.completion",
+                    "created": int(time.time()),
+                    "model": request[i].model,
+                    "choices": choice,
+                    "usage": {
+                        "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"],
+                        "completion_tokens": ret[i]["meta_info"]["completion_tokens"],
+                        "total_tokens": ret[i]["meta_info"]["prompt_tokens"]
+                        + ret[i]["meta_info"]["completion_tokens"],
+                    },
+                    "system_fingerprint": None,
+                },
+            }
+            responses.append(response)
+        return responses
+    else:
+        response = ChatCompletionResponse(
+            id=ret[0]["meta_info"]["id"],
+            model=request.model,
+            choices=choices,
+            usage=UsageInfo(
+                prompt_tokens=total_prompt_tokens,
+                completion_tokens=total_completion_tokens,
+                total_tokens=total_prompt_tokens + total_completion_tokens,
             ),
-            total_tokens=ret[0]["meta_info"]["prompt_tokens"]
-            + sum(item["meta_info"]["completion_tokens"] for item in ret),
-        ),
-    )
-
-    return response
+        )
+        return response
 
 
 async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
-    request = ChatCompletionRequest(**request_json)
-
-    # Prep the data needed for the underlying GenerateReqInput:
-    #  - prompt: The full prompt string.
-    #  - stop: Custom stop tokens.
-    #  - image_data: None or a list of image strings (URLs or base64 strings).
-    #    None skips any image processing in GenerateReqInput.
-    if not isinstance(request.messages, str):
-        # Apply chat template and its stop strings.
-        if chat_template_name is None:
-            prompt = tokenizer_manager.tokenizer.apply_chat_template(
-                request.messages, tokenize=False, add_generation_prompt=True
-            )
-            stop = request.stop
-            image_data = None
-        else:
-            conv = generate_chat_conv(request, chat_template_name)
-            prompt = conv.get_prompt()
-            image_data = conv.image_data
-            stop = conv.stop_str or []
-            if request.stop:
-                if isinstance(request.stop, str):
-                    stop.append(request.stop)
-                else:
-                    stop.extend(request.stop)
-    else:
-        # Use the raw prompt and stop strings if the messages is already a string.
-        prompt = request.messages
-        stop = request.stop
-        image_data = None
-
-    adapted_request = GenerateReqInput(
-        text=prompt,
-        image_data=image_data,
-        sampling_params={
-            "temperature": request.temperature,
-            "max_new_tokens": request.max_tokens,
-            "stop": stop,
-            "top_p": request.top_p,
-            "presence_penalty": request.presence_penalty,
-            "frequency_penalty": request.frequency_penalty,
-            "regex": request.regex,
-            "n": request.n,
-        },
-        stream=request.stream,
-    )
+    all_requests = [ChatCompletionRequest(**request_json)]
+    adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)
 
     if adapted_request.stream:
 
@@ -368,34 +832,8 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
 
     if not isinstance(ret, list):
         ret = [ret]
-    choices = []
-    total_prompt_tokens = 0
-    total_completion_tokens = 0
 
-    for idx, ret_item in enumerate(ret):
-        prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
-        completion_tokens = ret_item["meta_info"]["completion_tokens"]
-
-        choice_data = ChatCompletionResponseChoice(
-            index=idx,
-            message=ChatMessage(role="assistant", content=ret_item["text"]),
-            finish_reason=ret_item["meta_info"]["finish_reason"],
-        )
-
-        choices.append(choice_data)
-        total_prompt_tokens = prompt_tokens
-        total_completion_tokens += completion_tokens
-
-    response = ChatCompletionResponse(
-        id=ret[0]["meta_info"]["id"],
-        model=request.model,
-        choices=choices,
-        usage=UsageInfo(
-            prompt_tokens=total_prompt_tokens,
-            completion_tokens=total_completion_tokens,
-            total_tokens=total_prompt_tokens + total_completion_tokens,
-        ),
-    )
+    response = v1_chat_generate_response(request, ret)
 
     return response
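
The largest change above is sglang/srt/openai_api/adapter.py, which adds an OpenAI-compatible files/batches surface (v1_files_create, v1_batches, process_batch, v1_retrieve_batch, v1_retrieve_file, v1_retrieve_file_content). The following is a minimal sketch of how that surface could be exercised with the official openai Python client. The base URL, port, model name, and route wiring are assumptions (the routes live in server.py, which is not expanded in this diff); only the expected JSONL shape (custom_id plus body) and the two supported endpoints (/v1/chat/completions and /v1/completions) are taken from process_batch above.

import json
import time

import openai

# Port 30000 and the /v1 prefix are the usual sglang defaults; treat them as
# assumptions and adjust to however the server was launched.
client = openai.OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

# process_batch() reads each JSONL line's "custom_id" and "body"; the batch
# endpoint must be /v1/chat/completions or /v1/completions.
with open("batch_input.jsonl", "w") as f:
    for i in range(2):
        f.write(json.dumps({
            "custom_id": f"req-{i}",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "default",
                "messages": [{"role": "user", "content": f"Say hello #{i}"}],
                "max_tokens": 16,
            },
        }) + "\n")

# Upload the file (handled by v1_files_create) and start a batch (v1_batches).
uploaded = client.files.create(file=open("batch_input.jsonl", "rb"), purpose="batch")
batch = client.batches.create(
    input_file_id=uploaded.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
)

# Poll v1_retrieve_batch until process_batch() marks the job done, then pull
# the result file through v1_retrieve_file_content.
while batch.status not in ("completed", "failed", "cancelled"):
    time.sleep(1)
    batch = client.batches.retrieve(batch.id)

if batch.status == "completed":
    content = client.files.content(batch.output_file_id).content
    for line in content.decode().splitlines():
        record = json.loads(line)
        print(record["custom_id"], record["response"])

Note that batch_storage and the file maps added in this version are plain module-level dictionaries, so batch and file state is held in memory only and does not survive a server restart.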