sglang 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (83)
  1. sglang/api.py +13 -1
  2. sglang/bench_latency.py +10 -5
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/global_config.py +1 -1
  6. sglang/lang/backend/runtime_endpoint.py +60 -49
  7. sglang/lang/chat_template.py +10 -5
  8. sglang/lang/compiler.py +4 -0
  9. sglang/lang/interpreter.py +5 -2
  10. sglang/lang/ir.py +22 -4
  11. sglang/launch_server.py +8 -1
  12. sglang/srt/constrained/jump_forward.py +13 -2
  13. sglang/srt/conversation.py +50 -1
  14. sglang/srt/hf_transformers_utils.py +22 -23
  15. sglang/srt/layers/activation.py +24 -2
  16. sglang/srt/layers/decode_attention.py +338 -50
  17. sglang/srt/layers/extend_attention.py +3 -1
  18. sglang/srt/layers/fused_moe/__init__.py +1 -0
  19. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  20. sglang/srt/layers/fused_moe/layer.py +587 -0
  21. sglang/srt/layers/layernorm.py +3 -0
  22. sglang/srt/layers/logits_processor.py +64 -27
  23. sglang/srt/layers/radix_attention.py +41 -18
  24. sglang/srt/layers/sampler.py +154 -0
  25. sglang/srt/managers/controller_multi.py +2 -8
  26. sglang/srt/managers/controller_single.py +7 -10
  27. sglang/srt/managers/detokenizer_manager.py +20 -9
  28. sglang/srt/managers/io_struct.py +44 -11
  29. sglang/srt/managers/policy_scheduler.py +5 -2
  30. sglang/srt/managers/schedule_batch.py +59 -179
  31. sglang/srt/managers/tokenizer_manager.py +193 -84
  32. sglang/srt/managers/tp_worker.py +131 -50
  33. sglang/srt/mem_cache/memory_pool.py +82 -8
  34. sglang/srt/mm_utils.py +79 -7
  35. sglang/srt/model_executor/cuda_graph_runner.py +97 -28
  36. sglang/srt/model_executor/forward_batch_info.py +188 -82
  37. sglang/srt/model_executor/model_runner.py +269 -87
  38. sglang/srt/models/chatglm.py +6 -14
  39. sglang/srt/models/commandr.py +6 -2
  40. sglang/srt/models/dbrx.py +5 -1
  41. sglang/srt/models/deepseek.py +7 -3
  42. sglang/srt/models/deepseek_v2.py +12 -7
  43. sglang/srt/models/gemma.py +6 -2
  44. sglang/srt/models/gemma2.py +22 -8
  45. sglang/srt/models/gpt_bigcode.py +5 -1
  46. sglang/srt/models/grok.py +66 -398
  47. sglang/srt/models/internlm2.py +5 -1
  48. sglang/srt/models/llama2.py +7 -3
  49. sglang/srt/models/llama_classification.py +2 -2
  50. sglang/srt/models/llama_embedding.py +4 -0
  51. sglang/srt/models/llava.py +176 -59
  52. sglang/srt/models/minicpm.py +7 -3
  53. sglang/srt/models/mixtral.py +61 -255
  54. sglang/srt/models/mixtral_quant.py +6 -5
  55. sglang/srt/models/qwen.py +7 -4
  56. sglang/srt/models/qwen2.py +15 -5
  57. sglang/srt/models/qwen2_moe.py +7 -16
  58. sglang/srt/models/stablelm.py +6 -2
  59. sglang/srt/openai_api/adapter.py +149 -58
  60. sglang/srt/sampling/sampling_batch_info.py +209 -0
  61. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -4
  62. sglang/srt/server.py +107 -71
  63. sglang/srt/server_args.py +49 -15
  64. sglang/srt/utils.py +27 -18
  65. sglang/test/runners.py +38 -38
  66. sglang/test/simple_eval_common.py +9 -10
  67. sglang/test/simple_eval_gpqa.py +2 -1
  68. sglang/test/simple_eval_humaneval.py +2 -2
  69. sglang/test/simple_eval_math.py +2 -1
  70. sglang/test/simple_eval_mmlu.py +2 -1
  71. sglang/test/test_activation.py +55 -0
  72. sglang/test/test_programs.py +32 -5
  73. sglang/test/test_utils.py +37 -50
  74. sglang/version.py +1 -1
  75. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/METADATA +102 -27
  76. sglang-0.2.14.dist-info/RECORD +114 -0
  77. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
  78. sglang/launch_server_llavavid.py +0 -29
  79. sglang/srt/model_loader/model_loader.py +0 -292
  80. sglang/srt/model_loader/utils.py +0 -275
  81. sglang-0.2.12.dist-info/RECORD +0 -112
  82. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
  83. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2_moe.py
@@ -28,27 +28,26 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -366,6 +365,7 @@ class Qwen2MoeForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -376,20 +376,11 @@ class Qwen2MoeForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-
-    def compute_logits(
-        self,
-        input_ids: torch.Tensor,
-        hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
-    ) -> torch.Tensor:
-        logits = self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
-        )
-        return logits
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
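
The qwen2_moe.py hunks above illustrate the pattern applied across most of the model files in this release: the vLLM SiluAndMul, RMSNorm, LogitsProcessor, and Sampler imports are swapped for their sglang.srt.layers counterparts, a Sampler is instantiated next to the LogitsProcessor, and forward() now returns the sampled output together with the logits instead of exposing a separate compute_logits(). A minimal sketch of that contract follows; SketchForCausalLM and the backbone argument are placeholders, not classes from the package.

    # Minimal sketch of the 0.2.14 model forward contract, assuming the imports
    # shown in the diff above. SketchForCausalLM and `backbone` are placeholders.
    import torch
    from torch import nn
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead

    from sglang.srt.layers.logits_processor import LogitsProcessor
    from sglang.srt.layers.sampler import Sampler
    from sglang.srt.model_executor.forward_batch_info import InputMetadata


    class SketchForCausalLM(nn.Module):
        """`backbone` stands in for Qwen2MoeModel, StableLMEpochModel, etc."""

        def __init__(self, config, backbone: nn.Module):
            super().__init__()
            self.model = backbone
            self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
            self.logits_processor = LogitsProcessor(config)
            self.sampler = Sampler()  # new: sampling now happens inside the model

        @torch.no_grad()
        def forward(self, input_ids, positions, input_metadata: InputMetadata, input_embeds=None):
            hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
            logits_output = self.logits_processor(
                input_ids, hidden_states, self.lm_head.weight, input_metadata
            )
            # The old compute_logits() helper is gone; callers receive the sampled
            # token output and the logits output together.
            sample_output = self.sampler(logits_output, input_metadata.sampling_info)
            return sample_output, logits_output
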
sglang/srt/models/stablelm.py
@@ -24,7 +24,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -38,8 +37,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -249,6 +250,7 @@ class StableLmForCausalLM(nn.Module):
         self.model = StableLMEpochModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -259,9 +261,11 @@ class StableLmForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/openai_api/adapter.py
@@ -17,6 +17,7 @@ limitations under the License.
 
 import asyncio
 import json
+import logging
 import os
 import time
 import uuid
@@ -64,6 +65,8 @@ from sglang.srt.openai_api.protocol import (
     UsageInfo,
 )
 
+logger = logging.getLogger(__name__)
+
 chat_template_name = None
 
 
@@ -117,37 +120,48 @@ def create_streaming_error_response(
     return json_str
 
 
-def load_chat_template_for_openai_api(chat_template_arg):
+def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg):
     global chat_template_name
 
-    print(f"Use chat template: {chat_template_arg}")
+    logger.info(f"Use chat template: {chat_template_arg}")
     if not chat_template_exists(chat_template_arg):
         if not os.path.exists(chat_template_arg):
             raise RuntimeError(
                 f"Chat template {chat_template_arg} is not a built-in template name "
                 "or a valid chat template file path."
             )
-        with open(chat_template_arg, "r") as filep:
-            template = json.load(filep)
-            try:
-                sep_style = SeparatorStyle[template["sep_style"]]
-            except KeyError:
-                raise ValueError(
-                    f"Unknown separator style: {template['sep_style']}"
-                ) from None
-            register_conv_template(
-                Conversation(
-                    name=template["name"],
-                    system_template=template["system"] + "\n{system_message}",
-                    system_message=template.get("system_message", ""),
-                    roles=(template["user"], template["assistant"]),
-                    sep_style=sep_style,
-                    sep=template.get("sep", "\n"),
-                    stop_str=template["stop_str"],
-                ),
-                override=True,
+        if chat_template_arg.endswith(".jinja"):
+            with open(chat_template_arg, "r") as f:
+                chat_template = "".join(f.readlines()).strip("\n")
+            tokenizer_manager.tokenizer.chat_template = chat_template.replace(
+                "\\n", "\n"
             )
-        chat_template_name = template["name"]
+            chat_template_name = None
+        else:
+            assert chat_template_arg.endswith(
+                ".json"
+            ), "unrecognized format of chat template file"
+            with open(chat_template_arg, "r") as filep:
+                template = json.load(filep)
+                try:
+                    sep_style = SeparatorStyle[template["sep_style"]]
+                except KeyError:
+                    raise ValueError(
+                        f"Unknown separator style: {template['sep_style']}"
+                    ) from None
+                register_conv_template(
+                    Conversation(
+                        name=template["name"],
+                        system_template=template["system"] + "\n{system_message}",
+                        system_message=template.get("system_message", ""),
+                        roles=(template["user"], template["assistant"]),
+                        sep_style=sep_style,
+                        sep=template.get("sep", "\n"),
+                        stop_str=template["stop_str"],
+                    ),
+                    override=True,
+                )
+            chat_template_name = template["name"]
     else:
         chat_template_name = chat_template_arg
 
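
load_chat_template_for_openai_api now accepts either a Jinja file, which is installed directly as the tokenizer's chat_template, or the existing JSON conversation-template format. For the JSON path, the keys read above are name, system, system_message (optional), user, assistant, sep_style, sep (optional), and stop_str. The sketch below writes a hypothetical template file usable with --chat-template; every value is invented, and sep_style must name a member of sglang.srt.conversation.SeparatorStyle.

    import json

    # Hypothetical chat template file; only the keys come from the loader above,
    # the values are made up for illustration.
    example_template = {
        "name": "my-custom-chat",
        "system": "<|system|>",
        "system_message": "You are a helpful assistant.",
        "user": "<|user|>",
        "assistant": "<|assistant|>",
        "sep_style": "NO_COLON_SINGLE",  # must match a SeparatorStyle member name
        "sep": "\n",
        "stop_str": "<|user|>",
    }

    with open("my_template.json", "w") as f:
        json.dump(example_template, f, indent=2)
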
@@ -265,6 +279,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             request_data = json.loads(line)
             file_request_list.append(request_data)
             body = request_data["body"]
+
+            # Although streaming is supported for standalone completions, it is not supported in
+            # batch mode (multiple completions in single request).
+            if body.get("stream", False):
+                raise ValueError("Streaming requests are not supported in batch mode")
+
             if end_point == "/v1/chat/completions":
                 all_requests.append(ChatCompletionRequest(**body))
             elif end_point == "/v1/completions":
@@ -335,7 +355,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             }
 
     except Exception as e:
-        print("error in SGLang:", e)
+        logger.error("error in SGLang:", e)
         # Update batch status to "failed"
        retrieve_batch = batch_storage[batch_id]
        retrieve_batch.status = "failed"
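
The new guard rejects any batch line whose body asks for streaming. Each line of the uploaded batch file is a JSON object whose body field is an ordinary completion or chat-completion payload; the sketch below writes one such line with invented values. Only body is read in the hunk above; the custom_id/method/url envelope is an assumption based on the OpenAI batch file layout this endpoint mirrors.

    import json

    # One hypothetical line of a batch input file. Only "body" is read in the
    # hunk above; the envelope fields are assumptions from the OpenAI batch format.
    batch_line = {
        "custom_id": "request-1",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "placeholder-model",
            "messages": [{"role": "user", "content": "Hello"}],
            "stream": False,  # True is now rejected in batch mode
        },
    }

    with open("batch_input.jsonl", "w") as f:
        f.write(json.dumps(batch_line) + "\n")
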
@@ -372,20 +392,33 @@ async def v1_retrieve_file_content(file_id: str):
     return StreamingResponse(iter_file(), media_type="application/octet-stream")
 
 
-def v1_generate_request(all_requests):
+def v1_generate_request(all_requests: List[CompletionRequest]):
     prompts = []
     sampling_params_list = []
     return_logprobs = []
+    logprob_start_lens = []
     top_logprobs_nums = []
-    first_prompt_type = type(all_requests[0].prompt)
 
+    # NOTE: with openai API, the prompt's logprobs are always not computed
+    first_prompt_type = type(all_requests[0].prompt)
     for request in all_requests:
-        prompt = request.prompt
         assert (
-            type(prompt) == first_prompt_type
+            type(request.prompt) == first_prompt_type
         ), "All prompts must be of the same type in file input settings"
-        prompts.append(prompt)
+        if len(all_requests) > 1 and request.n > 1:
+            raise ValueError(
+                "Parallel sampling is not supported for completions from files"
+            )
+        if request.echo and request.logprobs:
+            logger.warning(
+                "Echo is not compatible with logprobs. "
+                "To compute logprobs of input prompt, please use SGLang /request API."
+            )
+
+    for request in all_requests:
+        prompts.append(request.prompt)
         return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
+        logprob_start_lens.append(-1)
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
         )
@@ -405,14 +438,11 @@ def v1_generate_request(all_requests):
                 "ignore_eos": request.ignore_eos,
             }
         )
-        if len(all_requests) > 1 and request.n > 1:
-            raise ValueError(
-                "Parallel sampling is not supported for completions from files"
-            )
 
     if len(all_requests) == 1:
         prompt = prompts[0]
         sampling_params_list = sampling_params_list[0]
+        logprob_start_lens = logprob_start_lens[0]
         return_logprobs = return_logprobs[0]
         top_logprobs_nums = top_logprobs_nums[0]
         if isinstance(prompt, str) or isinstance(prompt[0], str):
@@ -430,6 +460,7 @@ def v1_generate_request(all_requests):
         sampling_params=sampling_params_list,
         return_logprob=return_logprobs,
         top_logprobs_num=top_logprobs_nums,
+        logprob_start_len=logprob_start_lens,
         return_text_in_logprobs=True,
         stream=all_requests[0].stream,
     )
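
Taken together, these hunks make v1_generate_request record a per-request logprob_start_len of -1 (the OpenAI API never computes prompt logprobs) and pass it through to GenerateReqInput, collapsing every per-request list to a scalar when there is only one request. A hedged sketch of that single-request case follows; the field names come from the call above, while the prompt text and sampling values are invented.

    from sglang.srt.managers.io_struct import GenerateReqInput

    # Hypothetical single CompletionRequest mapped to a GenerateReqInput:
    # every per-request list has collapsed to a scalar.
    adapted_request = GenerateReqInput(
        text="Once upon a time",  # or input_ids=[...] for a token-id prompt
        sampling_params={"temperature": 0.7, "max_new_tokens": 64},  # illustrative
        return_logprob=False,
        top_logprobs_num=0,
        logprob_start_len=-1,  # new field: prompt logprobs are not computed
        return_text_in_logprobs=True,
        stream=False,
    )
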
@@ -569,27 +600,45 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     if adapted_request.stream:
 
         async def generate_stream_resp():
-            stream_buffer = ""
-            n_prev_token = 0
+            stream_buffers = {}
+            n_prev_tokens = {}
+            prompt_tokens = {}
+            completion_tokens = {}
             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
+                    index = content["index"]
+
+                    stream_buffer = stream_buffers.get(index, "")
+                    n_prev_token = n_prev_tokens.get(index, 0)
+
                     text = content["text"]
-                    prompt_tokens = content["meta_info"]["prompt_tokens"]
-                    completion_tokens = content["meta_info"]["completion_tokens"]
+                    prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
+                    completion_tokens[index] = content["meta_info"]["completion_tokens"]
 
                     if not stream_buffer: # The first chunk
                         if request.echo:
                             if isinstance(request.prompt, str):
                                 # for the case of single str prompts
                                 prompts = request.prompt
-                            elif isinstance(request.prompt, list) and isinstance(
-                                request.prompt[0], int
-                            ):
-                                prompts = tokenizer_manager.tokenizer.decode(
-                                    request.prompt, skip_special_tokens=True
-                                )
+                            elif isinstance(request.prompt, list):
+                                if isinstance(request.prompt[0], str):
+                                    # for the case of multiple str prompts
+                                    prompts = request.prompt[index // request.n]
+                                elif isinstance(request.prompt[0], int):
+                                    # for the case of single token ids prompt
+                                    prompts = tokenizer_manager.tokenizer.decode(
+                                        request.prompt, skip_special_tokens=True
+                                    )
+                                elif isinstance(request.prompt[0], list) and isinstance(
+                                    request.prompt[0][0], int
+                                ):
+                                    # for the case of multiple token ids prompts
+                                    prompts = tokenizer_manager.tokenizer.decode(
+                                        request.prompt[index // request.n],
+                                        skip_special_tokens=True,
+                                    )
 
                             # Prepend prompt in response text.
                             text = prompts + text
@@ -626,7 +675,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     delta = text[len(stream_buffer) :]
                     stream_buffer = stream_buffer + delta
                     choice_data = CompletionResponseStreamChoice(
-                        index=0,
+                        index=index,
                         text=delta,
                         logprobs=logprobs,
                         finish_reason=format_finish_reason(
@@ -639,12 +688,24 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                         choices=[choice_data],
                         model=request.model,
                     )
+
+                    stream_buffers[index] = stream_buffer
+                    n_prev_tokens[index] = n_prev_token
+
                     yield f"data: {chunk.model_dump_json()}\n\n"
                 if request.stream_options and request.stream_options.include_usage:
+                    total_prompt_tokens = sum(
+                        tokens
+                        for i, tokens in prompt_tokens.items()
+                        if i % request.n == 0
+                    )
+                    total_completion_tokens = sum(
+                        tokens for tokens in completion_tokens.values()
+                    )
                     usage = UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=prompt_tokens + completion_tokens,
+                        prompt_tokens=total_prompt_tokens,
+                        completion_tokens=total_completion_tokens,
+                        total_tokens=total_prompt_tokens + total_completion_tokens,
                     )
 
                     final_usage_chunk = CompletionStreamResponse(
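
Both streaming handlers now key their buffers by content["index"], so a request with multiple prompts or n > 1 can interleave chunks in a single SSE stream, and the final usage block is aggregated instead of taken from the last chunk: prompt tokens are counted once per group of n samples (indices 0, n, 2n, ...), while completion tokens are summed over every index. A small self-contained illustration of that rule with made-up counts:

    # Hypothetical request: 2 prompts with n=2 samples each -> indices 0..3.
    # Indices 0 and 1 share prompt A; indices 2 and 3 share prompt B.
    n = 2
    prompt_tokens = {0: 5, 1: 5, 2: 7, 3: 7}
    completion_tokens = {0: 10, 1: 12, 2: 9, 3: 11}

    # Same aggregation as the diff: count each prompt once, sum all completions.
    total_prompt_tokens = sum(t for i, t in prompt_tokens.items() if i % n == 0)
    total_completion_tokens = sum(completion_tokens.values())

    assert total_prompt_tokens == 12       # 5 + 7, not 24
    assert total_completion_tokens == 42   # 10 + 12 + 9 + 11
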
@@ -683,12 +744,18 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     return response
 
 
-def v1_chat_generate_request(all_requests, tokenizer_manager):
+def v1_chat_generate_request(
+    all_requests: List[ChatCompletionRequest], tokenizer_manager
+):
     input_ids = []
     sampling_params_list = []
     image_data_list = []
     return_logprobs = []
+    logprob_start_lens = []
     top_logprobs_nums = []
+
+    # NOTE: with openai API, the prompt's logprobs are always not computed
+
     for request in all_requests:
         # Prep the data needed for the underlying GenerateReqInput:
         # - prompt: The full prompt string.
@@ -721,6 +788,7 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
            image_data = None
        input_ids.append(prompt_ids)
        return_logprobs.append(request.logprobs)
+       logprob_start_lens.append(-1)
        top_logprobs_nums.append(request.top_logprobs)
        sampling_params_list.append(
            {
@@ -747,17 +815,20 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         sampling_params_list = sampling_params_list[0]
         image_data = image_data_list[0]
         return_logprobs = return_logprobs[0]
+        logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
     else:
         if isinstance(input_ids[0], str):
             prompt_kwargs = {"text": input_ids}
         else:
             prompt_kwargs = {"input_ids": input_ids}
+
     adapted_request = GenerateReqInput(
         **prompt_kwargs,
         image_data=image_data,
         sampling_params=sampling_params_list,
         return_logprob=return_logprobs,
+        logprob_start_len=logprob_start_lens,
         top_logprobs_num=top_logprobs_nums,
         stream=all_requests[0].stream,
         return_text_in_logprobs=True,
@@ -881,16 +952,23 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     if adapted_request.stream:
 
         async def generate_stream_resp():
-            is_first = True
-
-            stream_buffer = ""
-            n_prev_token = 0
+            is_firsts = {}
+            stream_buffers = {}
+            n_prev_tokens = {}
+            prompt_tokens = {}
+            completion_tokens = {}
             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
-                    prompt_tokens = content["meta_info"]["prompt_tokens"]
-                    completion_tokens = content["meta_info"]["completion_tokens"]
+                    index = content["index"]
+
+                    is_first = is_firsts.get(index, True)
+                    stream_buffer = stream_buffers.get(index, "")
+                    n_prev_token = n_prev_tokens.get(index, 0)
+
+                    prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
+                    completion_tokens[index] = content["meta_info"]["completion_tokens"]
                     if request.logprobs:
                         logprobs = to_openai_style_logprobs(
                             output_token_logprobs=content["meta_info"][
@@ -940,7 +1018,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         # First chunk with role
                         is_first = False
                         choice_data = ChatCompletionResponseStreamChoice(
-                            index=0,
+                            index=index,
                             delta=DeltaMessage(role="assistant"),
                             finish_reason=format_finish_reason(
                                 content["meta_info"]["finish_reason"]
@@ -958,7 +1036,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                     delta = text[len(stream_buffer) :]
                     stream_buffer = stream_buffer + delta
                     choice_data = ChatCompletionResponseStreamChoice(
-                        index=0,
+                        index=index,
                         delta=DeltaMessage(content=delta),
                         finish_reason=format_finish_reason(
                             content["meta_info"]["finish_reason"]
@@ -970,12 +1048,25 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         choices=[choice_data],
                         model=request.model,
                     )
+
+                    is_firsts[index] = is_first
+                    stream_buffers[index] = stream_buffer
+                    n_prev_tokens[index] = n_prev_token
+
                     yield f"data: {chunk.model_dump_json()}\n\n"
                 if request.stream_options and request.stream_options.include_usage:
+                    total_prompt_tokens = sum(
+                        tokens
+                        for i, tokens in prompt_tokens.items()
+                        if i % request.n == 0
+                    )
+                    total_completion_tokens = sum(
+                        tokens for tokens in completion_tokens.values()
+                    )
                     usage = UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=prompt_tokens + completion_tokens,
+                        prompt_tokens=total_prompt_tokens,
+                        completion_tokens=total_completion_tokens,
+                        total_tokens=total_prompt_tokens + total_completion_tokens,
                     )
 
                     final_usage_chunk = ChatCompletionStreamResponse(