sglang 0.2.9.post1__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. sglang/__init__.py +8 -0
  2. sglang/api.py +10 -2
  3. sglang/bench_latency.py +234 -74
  4. sglang/check_env.py +25 -2
  5. sglang/global_config.py +0 -1
  6. sglang/lang/backend/base_backend.py +3 -1
  7. sglang/lang/backend/openai.py +8 -3
  8. sglang/lang/backend/runtime_endpoint.py +46 -40
  9. sglang/lang/choices.py +164 -0
  10. sglang/lang/interpreter.py +6 -13
  11. sglang/lang/ir.py +11 -2
  12. sglang/srt/hf_transformers_utils.py +2 -2
  13. sglang/srt/layers/extend_attention.py +59 -7
  14. sglang/srt/layers/logits_processor.py +1 -1
  15. sglang/srt/layers/radix_attention.py +24 -14
  16. sglang/srt/layers/token_attention.py +28 -2
  17. sglang/srt/managers/io_struct.py +9 -4
  18. sglang/srt/managers/schedule_batch.py +98 -323
  19. sglang/srt/managers/tokenizer_manager.py +34 -16
  20. sglang/srt/managers/tp_worker.py +20 -22
  21. sglang/srt/mem_cache/memory_pool.py +74 -38
  22. sglang/srt/model_config.py +11 -0
  23. sglang/srt/model_executor/cuda_graph_runner.py +3 -3
  24. sglang/srt/model_executor/forward_batch_info.py +256 -0
  25. sglang/srt/model_executor/model_runner.py +51 -26
  26. sglang/srt/models/chatglm.py +1 -1
  27. sglang/srt/models/commandr.py +1 -1
  28. sglang/srt/models/dbrx.py +1 -1
  29. sglang/srt/models/deepseek.py +1 -1
  30. sglang/srt/models/deepseek_v2.py +199 -17
  31. sglang/srt/models/gemma.py +1 -1
  32. sglang/srt/models/gemma2.py +1 -1
  33. sglang/srt/models/gpt_bigcode.py +1 -1
  34. sglang/srt/models/grok.py +1 -1
  35. sglang/srt/models/internlm2.py +1 -1
  36. sglang/srt/models/llama2.py +1 -1
  37. sglang/srt/models/llama_classification.py +1 -1
  38. sglang/srt/models/llava.py +1 -2
  39. sglang/srt/models/llavavid.py +1 -2
  40. sglang/srt/models/minicpm.py +1 -1
  41. sglang/srt/models/mixtral.py +1 -1
  42. sglang/srt/models/mixtral_quant.py +1 -1
  43. sglang/srt/models/qwen.py +1 -1
  44. sglang/srt/models/qwen2.py +1 -1
  45. sglang/srt/models/qwen2_moe.py +1 -1
  46. sglang/srt/models/stablelm.py +1 -1
  47. sglang/srt/openai_api/adapter.py +151 -29
  48. sglang/srt/openai_api/protocol.py +7 -1
  49. sglang/srt/server.py +111 -84
  50. sglang/srt/server_args.py +12 -2
  51. sglang/srt/utils.py +25 -20
  52. sglang/test/run_eval.py +21 -10
  53. sglang/test/runners.py +237 -0
  54. sglang/test/simple_eval_common.py +12 -12
  55. sglang/test/simple_eval_gpqa.py +92 -0
  56. sglang/test/simple_eval_humaneval.py +5 -5
  57. sglang/test/simple_eval_math.py +72 -0
  58. sglang/test/test_utils.py +95 -14
  59. sglang/utils.py +15 -37
  60. sglang/version.py +1 -1
  61. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/METADATA +59 -48
  62. sglang-0.2.11.dist-info/RECORD +102 -0
  63. sglang-0.2.9.post1.dist-info/RECORD +0 -97
  64. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/LICENSE +0 -0
  65. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/WHEEL +0 -0
  66. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/top_level.txt +0 -0
sglang/srt/models/gpt_bigcode.py CHANGED
@@ -35,7 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.schedule_batch import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class GPTBigCodeAttention(nn.Module):
sglang/srt/models/grok.py CHANGED
@@ -52,7 +52,7 @@ from vllm.utils import print_warning_once
 from sglang.srt.layers.fused_moe import fused_moe
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 use_fused = True
 
sglang/srt/models/internlm2.py CHANGED
@@ -40,7 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class InternLM2MLP(nn.Module):
sglang/srt/models/llama2.py CHANGED
@@ -41,7 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class LlamaMLP(nn.Module):
sglang/srt/models/llama_classification.py CHANGED
@@ -25,7 +25,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitProcessorOutput
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama2 import LlamaModel
 
 
sglang/srt/models/llava.py CHANGED
@@ -32,13 +32,12 @@ from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.managers.schedule_batch import ForwardMode
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
     unpad_image_shape,
 )
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.models.llama2 import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
sglang/srt/models/llavavid.py CHANGED
@@ -26,13 +26,12 @@ from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.managers.schedule_batch import ForwardMode
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
     unpad_image_shape,
 )
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.models.llama2 import LlamaForCausalLM
 
 
sglang/srt/models/minicpm.py CHANGED
@@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class MiniCPMMLP(nn.Module):
sglang/srt/models/mixtral.py CHANGED
@@ -50,7 +50,7 @@ from vllm.utils import print_warning_once
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class MixtralMoE(nn.Module):
sglang/srt/models/mixtral_quant.py CHANGED
@@ -45,7 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class MixtralMLP(nn.Module):
sglang/srt/models/qwen.py CHANGED
@@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class QWenMLP(nn.Module):
sglang/srt/models/qwen2.py CHANGED
@@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 Qwen2Config = None
 
sglang/srt/models/qwen2_moe.py CHANGED
@@ -51,7 +51,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class Qwen2MoeMLP(nn.Module):
sglang/srt/models/stablelm.py CHANGED
@@ -40,7 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class StablelmMLP(nn.Module):
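Every model hunk above makes the same one-line change: InputMetadata (plus ForwardMode in the LLaVA variants) is now imported from the new sglang.srt.model_executor.forward_batch_info module instead of model_runner or schedule_batch. A minimal compatibility sketch for downstream code that imports these internals, assuming the old locations no longer re-export the names:

# Hypothetical shim for code that type-annotates against sglang internals.
try:
    # 0.2.11 layout: both names live in the new forward_batch_info module.
    from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
except ImportError:
    # 0.2.9.post1 layout, matching the imports removed in the hunks above.
    from sglang.srt.managers.schedule_batch import ForwardMode
    from sglang.srt.model_executor.model_runner import InputMetadata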
sglang/srt/openai_api/adapter.py CHANGED
@@ -53,6 +53,7 @@ from sglang.srt.openai_api.protocol import (
     CompletionStreamResponse,
     DeltaMessage,
     ErrorResponse,
+    FileDeleteResponse,
     FileRequest,
     FileResponse,
     LogProbs,
@@ -174,6 +175,20 @@ async def v1_files_create(file: UploadFile, purpose: str, file_storage_pth: str
         return {"error": "Invalid input", "details": e.errors()}
 
 
+async def v1_delete_file(file_id: str):
+    # Retrieve the file job from the in-memory storage
+    file_response = file_id_response.get(file_id)
+    if file_response is None:
+        raise HTTPException(status_code=404, detail="File not found")
+    file_path = file_id_storage.get(file_id)
+    if file_path is None:
+        raise HTTPException(status_code=404, detail="File not found")
+    os.remove(file_path)
+    del file_id_response[file_id]
+    del file_id_storage[file_id]
+    return FileDeleteResponse(id=file_id, deleted=True)
+
+
 async def v1_batches(tokenizer_manager, raw_request: Request):
     try:
         body = await raw_request.json()
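The new v1_delete_file helper mirrors the OpenAI Files API delete operation against the adapter's in-memory file_id_response/file_id_storage stores and returns the FileDeleteResponse model added in protocol.py. A hedged client-side sketch, assuming the server (sglang/srt/server.py, also changed in this release) wires the helper to DELETE /v1/files/{file_id}:

import requests

BASE_URL = "http://localhost:30000"  # assumed default sglang server address


def delete_uploaded_file(file_id: str) -> dict:
    # Returns the FileDeleteResponse payload on success,
    # e.g. {"id": "...", "object": "file", "deleted": true}.
    resp = requests.delete(f"{BASE_URL}/v1/files/{file_id}")
    resp.raise_for_status()  # an unknown file_id surfaces as HTTP 404
    return resp.json()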
@@ -251,7 +266,9 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         if end_point == "/v1/chat/completions":
             responses = v1_chat_generate_response(request, ret, to_file=True)
         else:
-            responses = v1_generate_response(request, ret, to_file=True)
+            responses = v1_generate_response(
+                request, ret, tokenizer_manager, to_file=True
+            )
 
     except Exception as e:
         error_json = {
@@ -285,6 +302,13 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         retrieve_batch = batch_storage[batch_id]
         retrieve_batch.output_file_id = output_file_id
         file_id_storage[output_file_id] = output_file_path
+        file_id_response[output_file_id] = FileResponse(
+            id=output_file_id,
+            bytes=os.path.getsize(output_file_path),
+            created_at=int(time.time()),
+            filename=f"{output_file_id}.jsonl",
+            purpose="batch_result",
+        )
         # Update batch status to "completed"
         retrieve_batch.status = "completed"
         retrieve_batch.completed_at = int(time.time())
@@ -339,6 +363,7 @@ def v1_generate_request(all_requests):
     return_logprobs = []
     top_logprobs_nums = []
     first_prompt_type = type(all_requests[0].prompt)
+
    for request in all_requests:
        prompt = request.prompt
        assert (
@@ -364,7 +389,7 @@
         )
         if len(all_requests) > 1 and request.n > 1:
             raise ValueError(
-                "Batch operation is not supported for completions from files"
+                "Parallel sampling is not supported for completions from files"
             )
 
     if len(all_requests) == 1:
@@ -381,6 +406,7 @@
         prompt_kwargs = {"text": prompts}
     else:
         prompt_kwargs = {"input_ids": prompts}
+
     adapted_request = GenerateReqInput(
         **prompt_kwargs,
         sampling_params=sampling_params_list,
@@ -389,35 +415,52 @@
         return_text_in_logprobs=True,
         stream=all_requests[0].stream,
     )
+
     if len(all_requests) == 1:
         return adapted_request, all_requests[0]
     return adapted_request, all_requests
 
 
-def v1_generate_response(request, ret, to_file=False):
+def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
     choices = []
     echo = False
 
-    if (not isinstance(request, List)) and request.echo:
+    if (not isinstance(request, list)) and request.echo:
         # TODO: handle the case propmt is token ids
-        if isinstance(request.prompt, list):
+        if isinstance(request.prompt, list) and isinstance(request.prompt[0], str):
+            # for the case of multiple str prompts
             prompts = request.prompt
+        elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list):
+            # for the case of multiple token ids prompts
+            prompts = [
+                tokenizer_manager.tokenizer.decode(prompt, skip_special_tokens=True)
+                for prompt in request.prompt
+            ]
+        elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int):
+            # for the case of single token ids prompt
+            prompts = [
+                tokenizer_manager.tokenizer.decode(
+                    request.prompt, skip_special_tokens=True
+                )
+            ]
         else:
+            # for the case of single str prompt
             prompts = [request.prompt]
         echo = True
 
     for idx, ret_item in enumerate(ret):
         text = ret_item["text"]
-        if isinstance(request, List) and request[idx].echo:
+        if isinstance(request, list) and request[idx].echo:
             echo = True
             text = request[idx].prompt + text
-        if (not isinstance(request, List)) and echo:
-            text = prompts[idx] + text
+        if (not isinstance(request, list)) and echo:
+            prompt_index = idx // request.n
+            text = prompts[prompt_index] + text
 
         logprobs = False
-        if isinstance(request, List) and request[idx].logprobs:
+        if isinstance(request, list) and request[idx].logprobs:
             logprobs = True
-        elif (not isinstance(request, List)) and request.logprobs:
+        elif (not isinstance(request, list)) and request.logprobs:
             logprobs = True
         if logprobs:
             if echo:
@@ -479,15 +522,18 @@ def v1_generate_response(request, ret, to_file=False):
             responses.append(response)
         return responses
     else:
+        prompt_tokens = sum(
+            ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
+        )
         completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
         response = CompletionResponse(
             id=ret[0]["meta_info"]["id"],
             model=request.model,
             choices=choices,
             usage=UsageInfo(
-                prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
+                prompt_tokens=prompt_tokens,
                 completion_tokens=completion_tokens,
-                total_tokens=ret[0]["meta_info"]["prompt_tokens"] + completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
             ),
         )
         return response
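The usage arithmetic changes because parallel sampling (request.n > 1) makes the backend return one entry in ret per generated sample, so each prompt's prompt_tokens value is repeated n times. Stepping through ret in strides of request.n counts every prompt once, while completion tokens are still summed over all samples. A small illustration with made-up token counts:

n = 2  # request.n: two samples per prompt
ret = [  # hypothetical meta_info entries for 2 prompts x 2 samples each
    {"meta_info": {"prompt_tokens": 10, "completion_tokens": 5}},
    {"meta_info": {"prompt_tokens": 10, "completion_tokens": 7}},
    {"meta_info": {"prompt_tokens": 20, "completion_tokens": 4}},
    {"meta_info": {"prompt_tokens": 20, "completion_tokens": 6}},
]

prompt_tokens = sum(ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), n))
completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
assert prompt_tokens == 30      # 10 + 20: each prompt counted once
assert completion_tokens == 22  # 5 + 7 + 4 + 6: every sample counted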
@@ -513,8 +559,18 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
 
                     if not stream_buffer:  # The first chunk
                         if request.echo:
+                            if isinstance(request.prompt, str):
+                                # for the case of single str prompts
+                                prompts = request.prompt
+                            elif isinstance(request.prompt, list) and isinstance(
+                                request.prompt[0], int
+                            ):
+                                prompts = tokenizer_manager.tokenizer.decode(
+                                    request.prompt, skip_special_tokens=True
+                                )
+
                             # Prepend prompt in response text.
-                            text = request.prompt + text
+                            text = prompts + text
 
                     if request.logprobs:
                         # The first chunk and echo is enabled.
@@ -539,7 +595,6 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                                 "output_top_logprobs"
                             ][n_prev_token:],
                         )
-
                         n_prev_token = len(
                             content["meta_info"]["output_token_logprobs"]
                         )
@@ -588,7 +643,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     if not isinstance(ret, list):
         ret = [ret]
 
-    response = v1_generate_response(request, ret)
+    response = v1_generate_response(request, ret, tokenizer_manager)
     return response
 
 
@@ -626,7 +681,7 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
             prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
         else:
             # Use the raw prompt and stop strings if the messages is already a string.
-            prompt = request.messages
+            prompt_ids = request.messages
             stop = request.stop
             image_data = None
         input_ids.append(prompt_ids)
@@ -647,12 +702,21 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         image_data_list.append(image_data)
     if len(all_requests) == 1:
         input_ids = input_ids[0]
+        if isinstance(input_ids, str):
+            prompt_kwargs = {"text": input_ids}
+        else:
+            prompt_kwargs = {"input_ids": input_ids}
         sampling_params_list = sampling_params_list[0]
         image_data = image_data_list[0]
         return_logprobs = return_logprobs[0]
         top_logprobs_nums = top_logprobs_nums[0]
+    else:
+        if isinstance(input_ids[0], str):
+            prompt_kwargs = {"text": input_ids}
+        else:
+            prompt_kwargs = {"input_ids": input_ids}
     adapted_request = GenerateReqInput(
-        input_ids=input_ids,
+        **prompt_kwargs,
         image_data=image_data,
         sampling_params=sampling_params_list,
         return_logprob=return_logprobs,
@@ -667,14 +731,12 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
 
 def v1_chat_generate_response(request, ret, to_file=False):
     choices = []
-    total_prompt_tokens = 0
-    total_completion_tokens = 0
 
     for idx, ret_item in enumerate(ret):
         logprobs = False
-        if isinstance(request, List) and request[idx].logprobs:
+        if isinstance(request, list) and request[idx].logprobs:
             logprobs = True
-        elif (not isinstance(request, List)) and request.logprobs:
+        elif (not isinstance(request, list)) and request.logprobs:
             logprobs = True
         if logprobs:
             logprobs = to_openai_style_logprobs(
@@ -707,8 +769,6 @@ def v1_chat_generate_response(request, ret, to_file=False):
             choice_logprobs = ChoiceLogprobs(content=token_logprobs)
         else:
             choice_logprobs = None
-        prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
-        completion_tokens = ret_item["meta_info"]["completion_tokens"]
 
         if to_file:
             # to make the choice data json serializable
@@ -727,8 +787,7 @@ def v1_chat_generate_response(request, ret, to_file=False):
             )
 
         choices.append(choice_data)
-        total_prompt_tokens += prompt_tokens
-        total_completion_tokens += completion_tokens
+
     if to_file:
         responses = []
 
@@ -755,14 +814,18 @@ def v1_chat_generate_response(request, ret, to_file=False):
             responses.append(response)
         return responses
     else:
+        prompt_tokens = sum(
+            ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
+        )
+        completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
         response = ChatCompletionResponse(
             id=ret[0]["meta_info"]["id"],
             model=request.model,
             choices=choices,
             usage=UsageInfo(
-                prompt_tokens=total_prompt_tokens,
-                completion_tokens=total_completion_tokens,
-                total_tokens=total_prompt_tokens + total_completion_tokens,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
             ),
         )
         return response
@@ -779,10 +842,58 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
             is_first = True
 
             stream_buffer = ""
+            n_prev_token = 0
             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
+                    prompt_tokens = content["meta_info"]["prompt_tokens"]
+                    completion_tokens = content["meta_info"]["completion_tokens"]
+                    if request.logprobs:
+                        logprobs = to_openai_style_logprobs(
+                            output_token_logprobs=content["meta_info"][
+                                "output_token_logprobs"
+                            ][n_prev_token:],
+                            output_top_logprobs=content["meta_info"][
+                                "output_top_logprobs"
+                            ][n_prev_token:],
+                        )
+
+                        n_prev_token = len(
+                            content["meta_info"]["output_token_logprobs"]
+                        )
+                        token_logprobs = []
+                        for token, logprob in zip(
+                            logprobs.tokens, logprobs.token_logprobs
+                        ):
+                            token_bytes = list(token.encode("utf-8"))
+                            top_logprobs = []
+                            if logprobs.top_logprobs:
+                                for top_token, top_logprob in logprobs.top_logprobs[
+                                    0
+                                ].items():
+                                    top_token_bytes = list(top_token.encode("utf-8"))
+                                    top_logprobs.append(
+                                        TopLogprob(
+                                            token=top_token,
+                                            bytes=top_token_bytes,
+                                            logprob=top_logprob,
+                                        )
+                                    )
+                            token_logprobs.append(
+                                ChatCompletionTokenLogprob(
+                                    token=token,
+                                    bytes=token_bytes,
+                                    logprob=logprob,
+                                    top_logprobs=top_logprobs,
+                                )
+                            )
+
+                        choice_logprobs = ChoiceLogprobs(content=token_logprobs)
+
+                    else:
+                        choice_logprobs = None
+
                     if is_first:
                         # First chunk with role
                         is_first = False
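With this and the following hunks, the streaming branch of v1_chat_completions fills in per-token ChoiceLogprobs and per-chunk usage instead of omitting them. A hedged client sketch using the openai Python package against sglang's OpenAI-compatible endpoint; the base URL and model name are placeholders:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

stream = client.chat.completions.create(
    model="default",  # placeholder: use the model the server was launched with
    messages=[{"role": "user", "content": "Say hello in one word."}],
    stream=True,
    logprobs=True,
    top_logprobs=2,
)
for chunk in stream:
    if not chunk.choices:
        continue
    choice = chunk.choices[0]
    token = choice.delta.content or ""
    if choice.logprobs and choice.logprobs.content:
        # Each streamed chunk now carries per-token logprobs alongside usage.
        print(repr(token), choice.logprobs.content[0].logprob)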
@@ -790,11 +901,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                             index=0,
                             delta=DeltaMessage(role="assistant"),
                             finish_reason=content["meta_info"]["finish_reason"],
+                            logprobs=choice_logprobs,
                         )
                         chunk = ChatCompletionStreamResponse(
                             id=content["meta_info"]["id"],
                             choices=[choice_data],
                             model=request.model,
+                            usage=UsageInfo(
+                                prompt_tokens=prompt_tokens,
+                                completion_tokens=completion_tokens,
+                                total_tokens=prompt_tokens + completion_tokens,
+                            ),
                         )
                         yield f"data: {chunk.model_dump_json()}\n\n"
 
@@ -805,11 +922,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         index=0,
                         delta=DeltaMessage(content=delta),
                         finish_reason=content["meta_info"]["finish_reason"],
+                        logprobs=choice_logprobs,
                     )
                     chunk = ChatCompletionStreamResponse(
                         id=content["meta_info"]["id"],
                         choices=[choice_data],
                         model=request.model,
+                        usage=UsageInfo(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=prompt_tokens + completion_tokens,
+                        ),
                     )
                     yield f"data: {chunk.model_dump_json()}\n\n"
             except ValueError as e:
@@ -830,7 +953,6 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
         ).__anext__()
     except ValueError as e:
         return create_error_response(str(e))
-
     if not isinstance(ret, list):
         ret = [ret]
 
sglang/srt/openai_api/protocol.py CHANGED
@@ -95,6 +95,12 @@ class FileResponse(BaseModel):
     purpose: str
 
 
+class FileDeleteResponse(BaseModel):
+    id: str
+    object: str = "file"
+    deleted: bool
+
+
 class BatchRequest(BaseModel):
     input_file_id: (
         str  # The ID of an uploaded file that contains requests for the new batch
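FileDeleteResponse is a plain pydantic model, so the JSON body returned by the delete path follows directly from its three fields. A minimal sketch of the serialized shape, using the same pydantic v2 model_dump_json() API the adapter already relies on:

from pydantic import BaseModel


class FileDeleteResponse(BaseModel):  # mirrors the class added above
    id: str
    object: str = "file"
    deleted: bool


print(FileDeleteResponse(id="file-abc123", deleted=True).model_dump_json())
# {"id":"file-abc123","object":"file","deleted":true}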
@@ -278,7 +284,7 @@ class DeltaMessage(BaseModel):
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
-    logprobs: Optional[LogProbs] = None
+    logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
     finish_reason: Optional[str] = None
 