sglang 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only.
Files changed (81)
  1. sglang/__init__.py +33 -26
  2. sglang/api.py +9 -1
  3. sglang/bench_latency.py +2 -2
  4. sglang/bench_serving.py +10 -1
  5. sglang/check_env.py +1 -1
  6. sglang/lang/backend/litellm.py +1 -1
  7. sglang/lang/backend/openai.py +1 -1
  8. sglang/lang/backend/runtime_endpoint.py +4 -4
  9. sglang/lang/interpreter.py +24 -9
  10. sglang/lang/ir.py +1 -1
  11. sglang/srt/constrained/__init__.py +15 -0
  12. sglang/srt/constrained/base_cache.py +15 -0
  13. sglang/srt/constrained/fsm_cache.py +36 -1
  14. sglang/srt/constrained/jump_forward.py +15 -0
  15. sglang/srt/conversation.py +26 -0
  16. sglang/srt/hf_transformers_utils.py +18 -1
  17. sglang/srt/layers/context_flashattention_nopad.py +15 -0
  18. sglang/srt/layers/extend_attention.py +15 -0
  19. sglang/srt/layers/fused_moe.py +15 -0
  20. sglang/srt/layers/linear.py +15 -0
  21. sglang/srt/layers/logits_processor.py +109 -72
  22. sglang/srt/layers/quantization/__init__.py +15 -0
  23. sglang/srt/layers/quantization/fp8.py +15 -0
  24. sglang/srt/layers/radix_attention.py +21 -3
  25. sglang/srt/layers/token_attention.py +16 -1
  26. sglang/srt/managers/{controller/manager_multi.py → controller_multi.py} +17 -2
  27. sglang/srt/managers/{controller/manager_single.py → controller_single.py} +17 -2
  28. sglang/srt/managers/detokenizer_manager.py +16 -1
  29. sglang/srt/managers/io_struct.py +38 -5
  30. sglang/srt/managers/{controller/schedule_heuristic.py → policy_scheduler.py} +37 -22
  31. sglang/srt/managers/{controller/infer_batch.py → schedule_batch.py} +85 -25
  32. sglang/srt/managers/tokenizer_manager.py +99 -57
  33. sglang/srt/managers/{controller/tp_worker.py → tp_worker.py} +177 -81
  34. sglang/srt/mem_cache/flush_cache.py +33 -0
  35. sglang/srt/{memory_pool.py → mem_cache/memory_pool.py} +16 -1
  36. sglang/srt/{managers/controller → mem_cache}/radix_cache.py +15 -0
  37. sglang/srt/mm_utils.py +15 -0
  38. sglang/srt/model_config.py +20 -0
  39. sglang/srt/{managers/controller → model_executor}/cuda_graph_runner.py +42 -18
  40. sglang/srt/{managers/controller → model_executor}/model_runner.py +51 -16
  41. sglang/srt/model_loader/model_loader.py +15 -0
  42. sglang/srt/model_loader/utils.py +16 -1
  43. sglang/srt/models/chatglm.py +16 -1
  44. sglang/srt/models/commandr.py +16 -1
  45. sglang/srt/models/dbrx.py +16 -1
  46. sglang/srt/models/deepseek.py +16 -1
  47. sglang/srt/models/deepseek_v2.py +532 -0
  48. sglang/srt/models/gemma.py +16 -1
  49. sglang/srt/models/gemma2.py +16 -1
  50. sglang/srt/models/gpt_bigcode.py +16 -1
  51. sglang/srt/models/grok.py +16 -1
  52. sglang/srt/models/internlm2.py +16 -1
  53. sglang/srt/models/llama2.py +16 -1
  54. sglang/srt/models/llama_classification.py +19 -4
  55. sglang/srt/models/llava.py +17 -2
  56. sglang/srt/models/llavavid.py +17 -2
  57. sglang/srt/models/minicpm.py +16 -1
  58. sglang/srt/models/mistral.py +15 -0
  59. sglang/srt/models/mixtral.py +16 -1
  60. sglang/srt/models/mixtral_quant.py +16 -1
  61. sglang/srt/models/qwen.py +16 -1
  62. sglang/srt/models/qwen2.py +16 -1
  63. sglang/srt/models/qwen2_moe.py +16 -1
  64. sglang/srt/models/stablelm.py +16 -1
  65. sglang/srt/models/yivl.py +15 -0
  66. sglang/srt/openai_api/adapter.py +545 -160
  67. sglang/srt/openai_api/protocol.py +65 -1
  68. sglang/srt/sampling_params.py +20 -4
  69. sglang/srt/server.py +90 -37
  70. sglang/srt/server_args.py +76 -17
  71. sglang/srt/utils.py +15 -0
  72. sglang/test/test_programs.py +5 -1
  73. sglang/utils.py +22 -0
  74. sglang/version.py +1 -1
  75. {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/METADATA +40 -12
  76. sglang-0.2.7.dist-info/RECORD +93 -0
  77. {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/WHEEL +1 -1
  78. sglang/srt/flush_cache.py +0 -18
  79. sglang-0.2.5.dist-info/RECORD +0 -92
  80. {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/LICENSE +0 -0
  81. {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,31 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  """Conversion between OpenAI APIs and native SRT APIs"""

  import asyncio
  import json
  import os
+ import time
+ import uuid
  from http import HTTPStatus
+ from typing import Dict, List, Optional

- from fastapi import Request
+ from fastapi import HTTPException, Request, UploadFile
  from fastapi.responses import JSONResponse, StreamingResponse
+ from pydantic import ValidationError

  from sglang.srt.conversation import (
      Conversation,
@@ -17,6 +36,8 @@ from sglang.srt.conversation import (
  )
  from sglang.srt.managers.io_struct import GenerateReqInput
  from sglang.srt.openai_api.protocol import (
+     BatchRequest,
+     BatchResponse,
      ChatCompletionRequest,
      ChatCompletionResponse,
      ChatCompletionResponseChoice,
@@ -30,6 +51,8 @@ from sglang.srt.openai_api.protocol import (
      CompletionStreamResponse,
      DeltaMessage,
      ErrorResponse,
+     FileRequest,
+     FileResponse,
      LogProbs,
      UsageInfo,
  )
@@ -37,6 +60,24 @@ from sglang.srt.openai_api.protocol import (
  chat_template_name = None


+ class FileMetadata:
+     def __init__(self, filename: str, purpose: str):
+         self.filename = filename
+         self.purpose = purpose
+
+
+ # In-memory storage for batch jobs and files
+ batch_storage: Dict[str, BatchResponse] = {}
+ file_id_request: Dict[str, FileMetadata] = {}
+ file_id_response: Dict[str, FileResponse] = {}
+ ## map file id to file path in SGlang backend
+ file_id_storage: Dict[str, str] = {}
+
+
+ # backend storage directory
+ storage_dir = None
+
+
  def create_error_response(
      message: str,
      err_type: str = "BadRequestError",
@@ -91,33 +132,364 @@ def load_chat_template_for_openai_api(chat_template_arg):
      chat_template_name = chat_template_arg


- async def v1_completions(tokenizer_manager, raw_request: Request):
-     request_json = await raw_request.json()
-     request = CompletionRequest(**request_json)
-     prompt = request.prompt
-     if isinstance(prompt, str) or isinstance(prompt[0], str):
-         prompt_kwargs = {"text": prompt}
+ async def v1_files_create(file: UploadFile, purpose: str, file_storage_pth: str = None):
+     try:
+         global storage_dir
+         if file_storage_pth:
+             storage_dir = file_storage_pth
+         # Read the file content
+         file_content = await file.read()
+
+         # Create an instance of RequestBody
+         request_body = FileRequest(file=file_content, purpose=purpose)
+
+         # Save the file to the sglang_oai_storage directory
+         os.makedirs(storage_dir, exist_ok=True)
+         file_id = f"backend_input_file-{uuid.uuid4()}"
+         filename = f"{file_id}.jsonl"
+         file_path = os.path.join(storage_dir, filename)
+
+         with open(file_path, "wb") as f:
+             f.write(request_body.file)
+
+         # add info to global file map
+         file_id_request[file_id] = FileMetadata(filename=file.filename, purpose=purpose)
+         file_id_storage[file_id] = file_path
+
+         # Return the response in the required format
+         response = FileResponse(
+             id=file_id,
+             bytes=len(request_body.file),
+             created_at=int(time.time()),
+             filename=file.filename,
+             purpose=request_body.purpose,
+         )
+         file_id_response[file_id] = response
+
+         return response
+     except ValidationError as e:
+         return {"error": "Invalid input", "details": e.errors()}
+
+
+ async def v1_batches(tokenizer_manager, raw_request: Request):
+     try:
+         body = await raw_request.json()
+
+         batch_request = BatchRequest(**body)
+
+         batch_id = f"batch_{uuid.uuid4()}"
+
+         # Create an instance of BatchResponse
+         batch_response = BatchResponse(
+             id=batch_id,
+             endpoint=batch_request.endpoint,
+             input_file_id=batch_request.input_file_id,
+             completion_window=batch_request.completion_window,
+             created_at=int(time.time()),
+             metadata=batch_request.metadata,
+         )
+
+         batch_storage[batch_id] = batch_response
+
+         # Start processing the batch asynchronously
+         asyncio.create_task(process_batch(tokenizer_manager, batch_id, batch_request))
+
+         # Return the initial batch_response
+         return batch_response
+
+     except ValidationError as e:
+         return {"error": "Invalid input", "details": e.errors()}
+     except Exception as e:
+         return {"error": str(e)}
+
+
+ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest):
+     try:
+         # Update the batch status to "in_progress"
+         batch_storage[batch_id].status = "in_progress"
+         batch_storage[batch_id].in_progress_at = int(time.time())
+
+         # Retrieve the input file content
+         input_file_request = file_id_request.get(batch_request.input_file_id)
+         if not input_file_request:
+             raise ValueError("Input file not found")
+
+         # Parse the JSONL file and process each request
+         input_file_path = file_id_storage.get(batch_request.input_file_id)
+         with open(input_file_path, "r", encoding="utf-8") as f:
+             lines = f.readlines()
+
+         total_requests = len(lines)
+         completed_requests = 0
+         failed_requests = 0
+
+         all_ret = []
+         end_point = batch_storage[batch_id].endpoint
+         file_request_list = []
+         all_requests = []
+         for line in lines:
+             request_data = json.loads(line)
+             file_request_list.append(request_data)
+             body = request_data["body"]
+             if end_point == "/v1/chat/completions":
+                 all_requests.append(ChatCompletionRequest(**body))
+             elif end_point == "/v1/completions":
+                 all_requests.append(CompletionRequest(**body))
+         if end_point == "/v1/chat/completions":
+             adapted_request, request = v1_chat_generate_request(
+                 all_requests, tokenizer_manager
+             )
+         elif end_point == "/v1/completions":
+             adapted_request, request = v1_generate_request(all_requests)
+         try:
+             ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
+             if not isinstance(ret, list):
+                 ret = [ret]
+             if end_point == "/v1/chat/completions":
+                 responses = v1_chat_generate_response(request, ret, to_file=True)
+             else:
+                 responses = v1_generate_response(request, ret, to_file=True)
+
+         except Exception as e:
+             error_json = {
+                 "id": f"batch_req_{uuid.uuid4()}",
+                 "custom_id": request_data.get("custom_id"),
+                 "response": None,
+                 "error": {"message": str(e)},
+             }
+             all_ret.append(error_json)
+             failed_requests += len(file_request_list)
+
+         for idx, response in enumerate(responses):
+             ## the batch_req here can be changed to be named within a batch granularity
+             response_json = {
+                 "id": f"batch_req_{uuid.uuid4()}",
+                 "custom_id": file_request_list[idx].get("custom_id"),
+                 "response": response,
+                 "error": None,
+             }
+             all_ret.append(response_json)
+             completed_requests += 1
+         # Write results to a new file
+         output_file_id = f"backend_result_file-{uuid.uuid4()}"
+         global storage_dir
+         output_file_path = os.path.join(storage_dir, f"{output_file_id}.jsonl")
+         with open(output_file_path, "w", encoding="utf-8") as f:
+             for ret in all_ret:
+                 f.write(json.dumps(ret) + "\n")
+
+         # Update batch response with output file information
+         retrieve_batch = batch_storage[batch_id]
+         retrieve_batch.output_file_id = output_file_id
+         file_id_storage[output_file_id] = output_file_path
+         # Update batch status to "completed"
+         retrieve_batch.status = "completed"
+         retrieve_batch.completed_at = int(time.time())
+         retrieve_batch.request_counts = {
+             "total": total_requests,
+             "completed": completed_requests,
+             "failed": failed_requests,
+         }
+
+     except Exception as e:
+         print("error in SGlang:", e)
+         # Update batch status to "failed"
+         retrieve_batch = batch_storage[batch_id]
+         retrieve_batch.status = "failed"
+         retrieve_batch.failed_at = int(time.time())
+         retrieve_batch.errors = {"message": str(e)}
+
+
+ async def v1_retrieve_batch(batch_id: str):
+     # Retrieve the batch job from the in-memory storage
+     batch_response = batch_storage.get(batch_id)
+     if batch_response is None:
+         raise HTTPException(status_code=404, detail="Batch not found")
+
+     return batch_response
+
+
+ async def v1_retrieve_file(file_id: str):
+     # Retrieve the batch job from the in-memory storage
+     file_response = file_id_response.get(file_id)
+     if file_response is None:
+         raise HTTPException(status_code=404, detail="File not found")
+     return file_response
+
+
+ async def v1_retrieve_file_content(file_id: str):
+     file_pth = file_id_storage.get(file_id)
+     if not file_pth or not os.path.exists(file_pth):
+         raise HTTPException(status_code=404, detail="File not found")
+
+     def iter_file():
+         with open(file_pth, mode="rb") as file_like:
+             yield from file_like
+
+     return StreamingResponse(iter_file(), media_type="application/octet-stream")
+
+
+ def v1_generate_request(all_requests):
+
+     prompts = []
+     sampling_params_list = []
+     first_prompt_type = type(all_requests[0].prompt)
+     for request in all_requests:
+         prompt = request.prompt
+         assert (
+             type(prompt) == first_prompt_type
+         ), "All prompts must be of the same type in file input settings"
+         prompts.append(prompt)
+         sampling_params_list.append(
+             {
+                 "temperature": request.temperature,
+                 "max_new_tokens": request.max_tokens,
+                 "stop": request.stop,
+                 "top_p": request.top_p,
+                 "presence_penalty": request.presence_penalty,
+                 "frequency_penalty": request.frequency_penalty,
+                 "regex": request.regex,
+                 "n": request.n,
+                 "ignore_eos": request.ignore_eos,
+             }
+         )
+         if len(all_requests) > 1 and request.n > 1:
+             raise ValueError(
+                 "Batch operation is not supported for completions from files"
+             )
+
+     if len(all_requests) == 1:
+         prompt = prompts[0]
+         sampling_params_list = sampling_params_list[0]
+         if isinstance(prompts, str) or isinstance(prompts[0], str):
+             prompt_kwargs = {"text": prompt}
+         else:
+             prompt_kwargs = {"input_ids": prompt}
      else:
-         prompt_kwargs = {"input_ids": prompt}
+         if isinstance(prompts[0], str):
+             prompt_kwargs = {"text": prompts}
+         else:
+             prompt_kwargs = {"input_ids": prompts}

      adapted_request = GenerateReqInput(
          **prompt_kwargs,
-         sampling_params={
-             "temperature": request.temperature,
-             "max_new_tokens": request.max_tokens,
-             "stop": request.stop,
-             "top_p": request.top_p,
-             "presence_penalty": request.presence_penalty,
-             "frequency_penalty": request.frequency_penalty,
-             "regex": request.regex,
-             "n": request.n,
-             "ignore_eos": request.ignore_eos,
-         },
-         return_logprob=request.logprobs is not None and request.logprobs > 0,
-         top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
+         sampling_params=sampling_params_list,
+         return_logprob=all_requests[0].logprobs is not None
+         and all_requests[0].logprobs > 0,
+         top_logprobs_num=(
+             all_requests[0].logprobs if all_requests[0].logprobs is not None else 0
+         ),
          return_text_in_logprobs=True,
-         stream=request.stream,
+         stream=all_requests[0].stream,
      )
+     if len(all_requests) == 1:
+         return adapted_request, all_requests[0]
+     return adapted_request, all_requests
+
+
+ def v1_generate_response(request, ret, to_file=False):
+     choices = []
+     echo = False
+
+     if (not isinstance(request, List)) and request.echo:
+         # TODO: handle the case propmt is token ids
+         if isinstance(request.prompt, list):
+             prompts = request.prompt
+         else:
+             prompts = [request.prompt]
+         echo = True
+
+     for idx, ret_item in enumerate(ret):
+         text = ret_item["text"]
+         if isinstance(request, List) and request[idx].echo:
+             echo = True
+             text = request[idx].prompt + text
+         if (not isinstance(request, List)) and echo:
+             text = prompts[idx] + text
+
+         logprobs = False
+         if isinstance(request, List) and request[idx].logprobs:
+             logprobs = True
+         elif (not isinstance(request, List)) and request.logprobs:
+             logprobs = True
+         if logprobs:
+             if echo:
+                 input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
+                 input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
+             else:
+                 input_token_logprobs = None
+                 input_top_logprobs = None
+
+             logprobs = to_openai_style_logprobs(
+                 input_token_logprobs=input_token_logprobs,
+                 input_top_logprobs=input_top_logprobs,
+                 output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
+                 output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
+             )
+         else:
+             logprobs = None
+
+         if to_file:
+             ## to make the choise data json serializable
+             choice_data = {
+                 "index": 0,
+                 "text": text,
+                 "logprobs": logprobs,
+                 "finish_reason": ret_item["meta_info"]["finish_reason"],
+             }
+         else:
+             choice_data = CompletionResponseChoice(
+                 index=idx,
+                 text=text,
+                 logprobs=logprobs,
+                 finish_reason=ret_item["meta_info"]["finish_reason"],
+             )
+
+         choices.append(choice_data)
+
+     if to_file:
+         responses = []
+         for i, choice in enumerate(choices):
+             response = {
+                 "status_code": 200,
+                 "request_id": ret[i]["meta_info"]["id"],
+                 "body": {
+                     ## remain the same but if needed we can change that
+                     "id": ret[i]["meta_info"]["id"],
+                     "object": "text_completion",
+                     "created": int(time.time()),
+                     "model": request[i].model,
+                     "choices": choice,
+                     "usage": {
+                         "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"],
+                         "completion_tokens": ret[i]["meta_info"]["completion_tokens"],
+                         "total_tokens": ret[i]["meta_info"]["prompt_tokens"]
+                         + ret[i]["meta_info"]["completion_tokens"],
+                     },
+                     "system_fingerprint": None,
+                 },
+             }
+             responses.append(response)
+         return responses
+     else:
+         completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
+         response = CompletionResponse(
+             id=ret[0]["meta_info"]["id"],
+             model=request.model,
+             choices=choices,
+             usage=UsageInfo(
+                 prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
+                 completion_tokens=completion_tokens,
+                 total_tokens=ret[0]["meta_info"]["prompt_tokens"] + completion_tokens,
+             ),
+         )
+         return response
+
+
+ async def v1_completions(tokenizer_manager, raw_request: Request):
+     request_json = await raw_request.json()
+     all_requests = [CompletionRequest(**request_json)]
+     adapted_request, request = v1_generate_request(all_requests)

      if adapted_request.stream:

@@ -140,29 +512,29 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                  if request.logprobs:
                      # The first chunk and echo is enabled.
                      if not stream_buffer and request.echo:
-                         prefill_token_logprobs = content["meta_info"][
-                             "prefill_token_logprobs"
+                         input_token_logprobs = content["meta_info"][
+                             "input_token_logprobs"
                          ]
-                         prefill_top_logprobs = content["meta_info"][
-                             "prefill_top_logprobs"
+                         input_top_logprobs = content["meta_info"][
+                             "input_top_logprobs"
                          ]
                      else:
-                         prefill_token_logprobs = None
-                         prefill_top_logprobs = None
+                         input_token_logprobs = None
+                         input_top_logprobs = None

                      logprobs = to_openai_style_logprobs(
-                         prefill_token_logprobs=prefill_token_logprobs,
-                         prefill_top_logprobs=prefill_top_logprobs,
-                         decode_token_logprobs=content["meta_info"][
-                             "decode_token_logprobs"
+                         input_token_logprobs=input_token_logprobs,
+                         input_top_logprobs=input_top_logprobs,
+                         output_token_logprobs=content["meta_info"][
+                             "output_token_logprobs"
                          ][n_prev_token:],
-                         decode_top_logprobs=content["meta_info"][
-                             "decode_top_logprobs"
+                         output_top_logprobs=content["meta_info"][
+                             "output_top_logprobs"
                          ][n_prev_token:],
                      )

                      n_prev_token = len(
-                         content["meta_info"]["decode_token_logprobs"]
+                         content["meta_info"]["output_token_logprobs"]
                      )
                  else:
                      logprobs = None
@@ -208,105 +580,144 @@ async def v1_completions(tokenizer_manager, raw_request: Request):

      if not isinstance(ret, list):
          ret = [ret]
-     choices = []

-     for idx, ret_item in enumerate(ret):
-         text = ret_item["text"]
+     response = v1_generate_response(request, ret)
+     return response

-         if request.echo:
-             text = request.prompt + text

-         if request.logprobs:
-             if request.echo:
-                 prefill_token_logprobs = ret_item["meta_info"]["prefill_token_logprobs"]
-                 prefill_top_logprobs = ret_item["meta_info"]["prefill_top_logprobs"]
+ def v1_chat_generate_request(all_requests, tokenizer_manager):
+
+     texts = []
+     sampling_params_list = []
+     image_data_list = []
+     for request in all_requests:
+         # Prep the data needed for the underlying GenerateReqInput:
+         #  - prompt: The full prompt string.
+         #  - stop: Custom stop tokens.
+         #  - image_data: None or a list of image strings (URLs or base64 strings).
+         #    None skips any image processing in GenerateReqInput.
+         if not isinstance(request.messages, str):
+             # Apply chat template and its stop strings.
+             if chat_template_name is None:
+                 prompt = tokenizer_manager.tokenizer.apply_chat_template(
+                     request.messages, tokenize=False, add_generation_prompt=True
+                 )
+                 stop = request.stop
+                 image_data = None
              else:
-                 prefill_token_logprobs = None
-                 prefill_top_logprobs = None
-
-             logprobs = to_openai_style_logprobs(
-                 prefill_token_logprobs=prefill_token_logprobs,
-                 prefill_top_logprobs=prefill_top_logprobs,
-                 decode_token_logprobs=ret_item["meta_info"]["decode_token_logprobs"],
-                 decode_top_logprobs=ret_item["meta_info"]["decode_top_logprobs"],
-             )
+                 conv = generate_chat_conv(request, chat_template_name)
+                 prompt = conv.get_prompt()
+                 image_data = conv.image_data
+                 stop = conv.stop_str or []
+                 if request.stop:
+                     if isinstance(request.stop, str):
+                         stop.append(request.stop)
+                     else:
+                         stop.extend(request.stop)
          else:
-             logprobs = None
-
-         choice_data = CompletionResponseChoice(
-             index=idx,
-             text=text,
-             logprobs=logprobs,
-             finish_reason=ret_item["meta_info"]["finish_reason"],
+             # Use the raw prompt and stop strings if the messages is already a string.
+             prompt = request.messages
+             stop = request.stop
+             image_data = None
+         texts.append(prompt)
+         sampling_params_list.append(
+             {
+                 "temperature": request.temperature,
+                 "max_new_tokens": request.max_tokens,
+                 "stop": stop,
+                 "top_p": request.top_p,
+                 "presence_penalty": request.presence_penalty,
+                 "frequency_penalty": request.frequency_penalty,
+                 "regex": request.regex,
+                 "n": request.n,
+             }
          )
-
-         choices.append(choice_data)
-
-     response = CompletionResponse(
-         id=ret[0]["meta_info"]["id"],
-         model=request.model,
-         choices=choices,
-         usage=UsageInfo(
-             prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
-             completion_tokens=sum(
-                 item["meta_info"]["completion_tokens"] for item in ret
-             ),
-             total_tokens=ret[0]["meta_info"]["prompt_tokens"]
-             + sum(item["meta_info"]["completion_tokens"] for item in ret),
-         ),
+         image_data_list.append(image_data)
+     if len(all_requests) == 1:
+         texts = texts[0]
+         sampling_params_list = sampling_params_list[0]
+         image_data = image_data_list[0]
+     adapted_request = GenerateReqInput(
+         text=texts,
+         image_data=image_data,
+         sampling_params=sampling_params_list,
+         stream=request.stream,
      )
+     if len(all_requests) == 1:
+         return adapted_request, all_requests[0]
+     return adapted_request, all_requests

-     return response

+ def v1_chat_generate_response(request, ret, to_file=False):
+     choices = []
+     total_prompt_tokens = 0
+     total_completion_tokens = 0

- async def v1_chat_completions(tokenizer_manager, raw_request: Request):
-     request_json = await raw_request.json()
-     request = ChatCompletionRequest(**request_json)
-
-     # Prep the data needed for the underlying GenerateReqInput:
-     #  - prompt: The full prompt string.
-     #  - stop: Custom stop tokens.
-     #  - image_data: None or a list of image strings (URLs or base64 strings).
-     #    None skips any image processing in GenerateReqInput.
-     if not isinstance(request.messages, str):
-         # Apply chat template and its stop strings.
-         if chat_template_name is None:
-             prompt = tokenizer_manager.tokenizer.apply_chat_template(
-                 request.messages, tokenize=False, add_generation_prompt=True
-             )
-             stop = request.stop
-             image_data = None
+     for idx, ret_item in enumerate(ret):
+         prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
+         completion_tokens = ret_item["meta_info"]["completion_tokens"]
+
+         if to_file:
+             ## to make the choice data json serializable
+             choice_data = {
+                 "index": 0,
+                 "message": {"role": "assistant", "content": ret_item["text"]},
+                 "logprobs": None,
+                 "finish_reason": ret_item["meta_info"]["finish_reason"],
+             }
          else:
-             conv = generate_chat_conv(request, chat_template_name)
-             prompt = conv.get_prompt()
-             image_data = conv.image_data
-             stop = conv.stop_str or []
-             if request.stop:
-                 if isinstance(request.stop, str):
-                     stop.append(request.stop)
-                 else:
-                     stop.extend(request.stop)
+             choice_data = ChatCompletionResponseChoice(
+                 index=idx,
+                 message=ChatMessage(role="assistant", content=ret_item["text"]),
+                 finish_reason=ret_item["meta_info"]["finish_reason"],
+             )
+
+         choices.append(choice_data)
+         total_prompt_tokens = prompt_tokens
+         total_completion_tokens += completion_tokens
+     if to_file:
+         responses = []
+
+         for i, choice in enumerate(choices):
+             response = {
+                 "status_code": 200,
+                 "request_id": ret[i]["meta_info"]["id"],
+                 "body": {
+                     ## remain the same but if needed we can change that
+                     "id": ret[i]["meta_info"]["id"],
+                     "object": "chat.completion",
+                     "created": int(time.time()),
+                     "model": request[i].model,
+                     "choices": choice,
+                     "usage": {
+                         "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"],
+                         "completion_tokens": ret[i]["meta_info"]["completion_tokens"],
+                         "total_tokens": ret[i]["meta_info"]["prompt_tokens"]
+                         + ret[i]["meta_info"]["completion_tokens"],
+                     },
+                     "system_fingerprint": None,
+                 },
+             }
+             responses.append(response)
+         return responses
      else:
-         # Use the raw prompt and stop strings if the messages is already a string.
-         prompt = request.messages
-         stop = request.stop
-         image_data = None
+         response = ChatCompletionResponse(
+             id=ret[0]["meta_info"]["id"],
+             model=request.model,
+             choices=choices,
+             usage=UsageInfo(
+                 prompt_tokens=total_prompt_tokens,
+                 completion_tokens=total_completion_tokens,
+                 total_tokens=total_prompt_tokens + total_completion_tokens,
+             ),
+         )
+         return response

-     adapted_request = GenerateReqInput(
-         text=prompt,
-         image_data=image_data,
-         sampling_params={
-             "temperature": request.temperature,
-             "max_new_tokens": request.max_tokens,
-             "stop": stop,
-             "top_p": request.top_p,
-             "presence_penalty": request.presence_penalty,
-             "frequency_penalty": request.frequency_penalty,
-             "regex": request.regex,
-             "n": request.n,
-         },
-         stream=request.stream,
-     )
+
+ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
+     request_json = await raw_request.json()
+     all_requests = [ChatCompletionRequest(**request_json)]
+     adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)

      if adapted_request.stream:

@@ -368,43 +779,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):

      if not isinstance(ret, list):
          ret = [ret]
-     choices = []
-     total_prompt_tokens = 0
-     total_completion_tokens = 0
-
-     for idx, ret_item in enumerate(ret):
-         prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
-         completion_tokens = ret_item["meta_info"]["completion_tokens"]

-         choice_data = ChatCompletionResponseChoice(
-             index=idx,
-             message=ChatMessage(role="assistant", content=ret_item["text"]),
-             finish_reason=ret_item["meta_info"]["finish_reason"],
-         )
-
-         choices.append(choice_data)
-         total_prompt_tokens = prompt_tokens
-         total_completion_tokens += completion_tokens
-
-     response = ChatCompletionResponse(
-         id=ret[0]["meta_info"]["id"],
-         model=request.model,
-         choices=choices,
-         usage=UsageInfo(
-             prompt_tokens=total_prompt_tokens,
-             completion_tokens=total_completion_tokens,
-             total_tokens=total_prompt_tokens + total_completion_tokens,
-         ),
-     )
+     response = v1_chat_generate_response(request, ret)

      return response


  def to_openai_style_logprobs(
-     prefill_token_logprobs=None,
-     decode_token_logprobs=None,
-     prefill_top_logprobs=None,
-     decode_top_logprobs=None,
+     input_token_logprobs=None,
+     output_token_logprobs=None,
+     input_top_logprobs=None,
+     output_top_logprobs=None,
  ):
      ret_logprobs = LogProbs()

@@ -425,13 +810,13 @@ def to_openai_style_logprobs(
              else:
                  ret_logprobs.top_logprobs.append(None)

-     if prefill_token_logprobs is not None:
-         append_token_logprobs(prefill_token_logprobs)
-     if decode_token_logprobs is not None:
-         append_token_logprobs(decode_token_logprobs)
-     if prefill_top_logprobs is not None:
-         append_top_logprobs(prefill_top_logprobs)
-     if decode_top_logprobs is not None:
-         append_top_logprobs(decode_top_logprobs)
+     if input_token_logprobs is not None:
+         append_token_logprobs(input_token_logprobs)
+     if output_token_logprobs is not None:
+         append_token_logprobs(output_token_logprobs)
+     if input_top_logprobs is not None:
+         append_top_logprobs(input_top_logprobs)
+     if output_top_logprobs is not None:
+         append_top_logprobs(output_top_logprobs)

      return ret_logprobs
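
The adapter diff above adds OpenAI-style file and batch handlers (v1_files_create, v1_batches, process_batch, v1_retrieve_batch, v1_retrieve_file, v1_retrieve_file_content). The sketch below shows how a client might drive that workflow. It is a hypothetical usage example, not part of the diff: the base URL, the route paths (/v1/files, /v1/batches, /v1/batches/{id}, /v1/files/{id}/content), and the model name are assumptions; route registration lives in sglang/srt/server.py, which is changed in this release but not shown here. Only the "custom_id" and "body" fields of each JSONL line are read by process_batch.

# Hypothetical client-side sketch; assumes an SGLang server at http://localhost:30000
# with the new handlers mounted at OpenAI-style routes (the exact routes are in server.py).
import json
import time

import requests

base_url = "http://localhost:30000"

# 1) Write a JSONL input file; process_batch() reads "custom_id" and "body" per line.
with open("batch_input.jsonl", "w") as f:
    line = {
        "custom_id": "request-1",
        "body": {
            "model": "default",
            "messages": [{"role": "user", "content": "Say hello."}],
            "max_tokens": 32,
        },
    }
    f.write(json.dumps(line) + "\n")

# 2) Upload it; v1_files_create stores it under storage_dir and returns a file id.
with open("batch_input.jsonl", "rb") as f:
    file_obj = requests.post(
        f"{base_url}/v1/files",
        files={"file": f},
        data={"purpose": "batch"},
    ).json()

# 3) Create the batch; v1_batches schedules process_batch() as an asyncio task.
batch = requests.post(
    f"{base_url}/v1/batches",
    json={
        "input_file_id": file_obj["id"],
        "endpoint": "/v1/chat/completions",
        "completion_window": "24h",
    },
).json()

# 4) Poll the in-memory batch record until process_batch() marks it done.
while batch["status"] not in ("completed", "failed"):
    time.sleep(1)
    batch = requests.get(f"{base_url}/v1/batches/{batch['id']}").json()

# 5) Stream back the JSONL result file written by process_batch().
if batch["status"] == "completed":
    result = requests.get(f"{base_url}/v1/files/{batch['output_file_id']}/content")
    print(result.text)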