sglang 0.2.13__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. sglang/api.py +6 -0
  2. sglang/bench_latency.py +7 -3
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/lang/chat_template.py +10 -5
  6. sglang/lang/compiler.py +4 -0
  7. sglang/lang/interpreter.py +1 -0
  8. sglang/lang/ir.py +9 -0
  9. sglang/launch_server.py +8 -1
  10. sglang/srt/conversation.py +50 -1
  11. sglang/srt/hf_transformers_utils.py +22 -23
  12. sglang/srt/layers/activation.py +24 -1
  13. sglang/srt/layers/decode_attention.py +338 -50
  14. sglang/srt/layers/fused_moe/layer.py +2 -2
  15. sglang/srt/layers/layernorm.py +3 -0
  16. sglang/srt/layers/logits_processor.py +60 -23
  17. sglang/srt/layers/radix_attention.py +3 -4
  18. sglang/srt/layers/sampler.py +154 -0
  19. sglang/srt/managers/controller_multi.py +2 -8
  20. sglang/srt/managers/controller_single.py +7 -10
  21. sglang/srt/managers/detokenizer_manager.py +20 -9
  22. sglang/srt/managers/io_struct.py +44 -11
  23. sglang/srt/managers/policy_scheduler.py +5 -2
  24. sglang/srt/managers/schedule_batch.py +52 -167
  25. sglang/srt/managers/tokenizer_manager.py +192 -83
  26. sglang/srt/managers/tp_worker.py +130 -43
  27. sglang/srt/mem_cache/memory_pool.py +82 -8
  28. sglang/srt/mm_utils.py +79 -7
  29. sglang/srt/model_executor/cuda_graph_runner.py +49 -11
  30. sglang/srt/model_executor/forward_batch_info.py +59 -27
  31. sglang/srt/model_executor/model_runner.py +210 -61
  32. sglang/srt/models/chatglm.py +4 -12
  33. sglang/srt/models/commandr.py +5 -1
  34. sglang/srt/models/dbrx.py +5 -1
  35. sglang/srt/models/deepseek.py +5 -1
  36. sglang/srt/models/deepseek_v2.py +5 -1
  37. sglang/srt/models/gemma.py +5 -1
  38. sglang/srt/models/gemma2.py +15 -7
  39. sglang/srt/models/gpt_bigcode.py +5 -1
  40. sglang/srt/models/grok.py +16 -2
  41. sglang/srt/models/internlm2.py +5 -1
  42. sglang/srt/models/llama2.py +7 -3
  43. sglang/srt/models/llama_classification.py +2 -2
  44. sglang/srt/models/llama_embedding.py +4 -0
  45. sglang/srt/models/llava.py +176 -59
  46. sglang/srt/models/minicpm.py +5 -1
  47. sglang/srt/models/mixtral.py +5 -1
  48. sglang/srt/models/mixtral_quant.py +5 -1
  49. sglang/srt/models/qwen.py +5 -2
  50. sglang/srt/models/qwen2.py +13 -3
  51. sglang/srt/models/qwen2_moe.py +5 -14
  52. sglang/srt/models/stablelm.py +5 -1
  53. sglang/srt/openai_api/adapter.py +117 -37
  54. sglang/srt/sampling/sampling_batch_info.py +209 -0
  55. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -0
  56. sglang/srt/server.py +84 -56
  57. sglang/srt/server_args.py +43 -15
  58. sglang/srt/utils.py +26 -16
  59. sglang/test/runners.py +23 -31
  60. sglang/test/simple_eval_common.py +9 -10
  61. sglang/test/simple_eval_gpqa.py +2 -1
  62. sglang/test/simple_eval_humaneval.py +2 -2
  63. sglang/test/simple_eval_math.py +2 -1
  64. sglang/test/simple_eval_mmlu.py +2 -1
  65. sglang/test/test_activation.py +55 -0
  66. sglang/test/test_utils.py +36 -53
  67. sglang/version.py +1 -1
  68. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/METADATA +92 -25
  69. sglang-0.2.14.dist-info/RECORD +114 -0
  70. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
  71. sglang/launch_server_llavavid.py +0 -29
  72. sglang-0.2.13.dist-info/RECORD +0 -112
  73. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
  74. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/models/mixtral_quant.py CHANGED
@@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -333,6 +334,7 @@ class QuantMixtralForCausalLM(nn.Module):
         self.model = MixtralModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -343,9 +345,11 @@ class QuantMixtralForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
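
The model-side pattern repeated across this release: each causal LM constructs a Sampler next to its LogitsProcessor, and forward() returns a (sample_output, logits_output) pair instead of bare logits, so sampling happens inside the model using input_metadata.sampling_info. A minimal, self-contained sketch of that shape, using toy stand-in classes rather than sglang's actual Sampler/LogitsProcessor signatures:

    import torch
    import torch.nn as nn


    class ToyLogitsProcessor(nn.Module):
        # Stand-in for a logits processor: project the last hidden state to vocab logits.
        def forward(self, hidden_states, lm_head_weight):
            return hidden_states[:, -1, :] @ lm_head_weight.T


    class ToySampler(nn.Module):
        # Stand-in for a sampler: temperature softmax followed by a multinomial draw.
        def forward(self, logits, temperature: float = 1.0):
            probs = torch.softmax(logits / temperature, dim=-1)
            return torch.multinomial(probs, num_samples=1).squeeze(-1)


    class ToyCausalLM(nn.Module):
        def __init__(self, hidden_size: int = 16, vocab_size: int = 32):
            super().__init__()
            self.backbone = nn.Linear(hidden_size, hidden_size)  # stands in for the transformer
            self.lm_head_weight = nn.Parameter(torch.randn(vocab_size, hidden_size))
            self.logits_processor = ToyLogitsProcessor()
            self.sampler = ToySampler()

        @torch.no_grad()
        def forward(self, input_embeds):
            hidden_states = self.backbone(input_embeds)
            logits_output = self.logits_processor(hidden_states, self.lm_head_weight)
            sample_output = self.sampler(logits_output)
            # 0.2.14 convention: return the sampled token ids alongside the logits.
            return sample_output, logits_output


    next_token_ids, logits = ToyCausalLM()(torch.randn(2, 5, 16))
    print(next_token_ids.shape, logits.shape)  # torch.Size([2]) torch.Size([2, 32])
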
sglang/srt/models/qwen.py CHANGED
@@ -39,6 +39,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -251,6 +252,7 @@ class QWenLMHeadModel(nn.Module):
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -260,10 +262,11 @@
         input_metadata: InputMetadata,
     ):
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        next_tokens = self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        return next_tokens
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/qwen2.py CHANGED
@@ -38,7 +38,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 Qwen2Config = None
@@ -275,6 +277,8 @@ class Qwen2ForCausalLM(nn.Module):
         self.model = Qwen2Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
     @torch.no_grad()
     def forward(
@@ -283,11 +287,17 @@ class Qwen2ForCausalLM(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
-        )
+        if not get_embedding:
+            logits_output = self.logits_processor(
+                input_ids, hidden_states, self.lm_head.weight, input_metadata
+            )
+            sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+            return sample_output, logits_output
+        else:
+            return self.pooler(hidden_states, input_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
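
Qwen2 additionally gains an embedding path: when forward() is called with get_embedding=True, the logits/sampling branch is skipped and a Pooler returns the last-token hidden state, L2-normalized. A rough sketch of last-token pooling over a packed batch; the helper below is hypothetical and illustrative, not the Pooler class itself:

    import torch
    import torch.nn.functional as F


    def last_token_pool(hidden_states, seq_lens):
        # hidden_states: (total_tokens, hidden_size) for a packed batch of sequences
        # seq_lens:      (batch,) length of each sequence in the packed layout
        last_indices = torch.cumsum(seq_lens, dim=0) - 1  # index of each sequence's final token
        pooled = hidden_states[last_indices]
        return F.normalize(pooled, p=2, dim=-1)           # mirrors normalize=True in the pooler config


    hidden = torch.randn(7, 8)  # e.g. two packed sequences of lengths 3 and 4
    embeddings = last_token_pool(hidden, torch.tensor([3, 4]))
    print(embeddings.shape)     # torch.Size([2, 8])
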
sglang/srt/models/qwen2_moe.py CHANGED
@@ -35,10 +35,8 @@ from vllm.model_executor.layers.linear import (
     ReplicatedLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -49,6 +47,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -366,6 +365,7 @@ class Qwen2MoeForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -376,20 +376,11 @@ class Qwen2MoeForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-
-    def compute_logits(
-        self,
-        input_ids: torch.Tensor,
-        hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
-    ) -> torch.Tensor:
-        logits = self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
-        )
-        return logits
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/stablelm.py CHANGED
@@ -40,6 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -249,6 +250,7 @@ class StableLmForCausalLM(nn.Module):
         self.model = StableLMEpochModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -259,9 +261,11 @@ class StableLmForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
        )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/openai_api/adapter.py CHANGED
@@ -17,6 +17,7 @@ limitations under the License.
 
 import asyncio
 import json
+import logging
 import os
 import time
 import uuid
@@ -64,6 +65,8 @@ from sglang.srt.openai_api.protocol import (
     UsageInfo,
 )
 
+logger = logging.getLogger(__name__)
+
 chat_template_name = None
 
 
@@ -120,7 +123,7 @@ def create_streaming_error_response(
 def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg):
     global chat_template_name
 
-    print(f"Use chat template: {chat_template_arg}")
+    logger.info(f"Use chat template: {chat_template_arg}")
     if not chat_template_exists(chat_template_arg):
         if not os.path.exists(chat_template_arg):
             raise RuntimeError(
@@ -276,6 +279,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         request_data = json.loads(line)
         file_request_list.append(request_data)
         body = request_data["body"]
+
+        # Although streaming is supported for standalone completions, it is not supported in
+        # batch mode (multiple completions in single request).
+        if body.get("stream", False):
+            raise ValueError("Streaming requests are not supported in batch mode")
+
         if end_point == "/v1/chat/completions":
             all_requests.append(ChatCompletionRequest(**body))
         elif end_point == "/v1/completions":
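
For context, a hedged sketch of the batch-file validation this hunk adds: each JSONL line carries a request body, and any body asking for streaming is rejected before per-line requests are built. The helper name and the minimal body fields below are illustrative, not sglang's process_batch itself:

    import json


    def parse_batch_lines(jsonl_text):
        # One JSON request per line; streaming is refused because a batch bundles
        # many completions into a single request.
        requests = []
        for line in jsonl_text.splitlines():
            if not line.strip():
                continue
            request_data = json.loads(line)
            if request_data["body"].get("stream", False):
                raise ValueError("Streaming requests are not supported in batch mode")
            requests.append(request_data)
        return requests


    print(len(parse_batch_lines('{"custom_id": "1", "body": {"model": "m", "prompt": "hi"}}')))  # 1
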
@@ -346,7 +355,7 @@
         }
 
     except Exception as e:
-        print("error in SGLang:", e)
+        logger.error("error in SGLang:", e)
         # Update batch status to "failed"
         retrieve_batch = batch_storage[batch_id]
         retrieve_batch.status = "failed"
@@ -383,20 +392,33 @@ async def v1_retrieve_file_content(file_id: str):
     return StreamingResponse(iter_file(), media_type="application/octet-stream")
 
 
-def v1_generate_request(all_requests):
+def v1_generate_request(all_requests: List[CompletionRequest]):
    prompts = []
     sampling_params_list = []
     return_logprobs = []
+    logprob_start_lens = []
     top_logprobs_nums = []
-    first_prompt_type = type(all_requests[0].prompt)
 
+    # NOTE: with openai API, the prompt's logprobs are always not computed
+    first_prompt_type = type(all_requests[0].prompt)
     for request in all_requests:
-        prompt = request.prompt
         assert (
-            type(prompt) == first_prompt_type
+            type(request.prompt) == first_prompt_type
         ), "All prompts must be of the same type in file input settings"
-        prompts.append(prompt)
+        if len(all_requests) > 1 and request.n > 1:
+            raise ValueError(
+                "Parallel sampling is not supported for completions from files"
+            )
+        if request.echo and request.logprobs:
+            logger.warning(
+                "Echo is not compatible with logprobs. "
+                "To compute logprobs of input prompt, please use SGLang /request API."
+            )
+
+    for request in all_requests:
+        prompts.append(request.prompt)
         return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
+        logprob_start_lens.append(-1)
        top_logprobs_nums.append(
            request.logprobs if request.logprobs is not None else 0
        )
@@ -416,14 +438,11 @@ def v1_generate_request(all_requests):
                "ignore_eos": request.ignore_eos,
            }
        )
-        if len(all_requests) > 1 and request.n > 1:
-            raise ValueError(
-                "Parallel sampling is not supported for completions from files"
-            )
 
     if len(all_requests) == 1:
         prompt = prompts[0]
         sampling_params_list = sampling_params_list[0]
+        logprob_start_lens = logprob_start_lens[0]
         return_logprobs = return_logprobs[0]
         top_logprobs_nums = top_logprobs_nums[0]
         if isinstance(prompt, str) or isinstance(prompt[0], str):
@@ -441,6 +460,7 @@ def v1_generate_request(all_requests):
         sampling_params=sampling_params_list,
         return_logprob=return_logprobs,
         top_logprobs_num=top_logprobs_nums,
+        logprob_start_len=logprob_start_lens,
         return_text_in_logprobs=True,
         stream=all_requests[0].stream,
     )
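
The reshaped v1_generate_request follows one rule throughout: build per-request lists (prompts, logprob settings, and the new logprob_start_lens pinned to -1, since prompt logprobs are never computed through the OpenAI API), then collapse them to bare values when only a single request is present. A small sketch of that flow under assumed request fields; ToyCompletionRequest and build_generate_args are illustrative names, not sglang APIs:

    from dataclasses import dataclass
    from typing import List, Optional, Union


    @dataclass
    class ToyCompletionRequest:
        # Illustrative stand-in for the OpenAI-style completion request.
        prompt: Union[str, List[int]]
        logprobs: Optional[int] = None
        n: int = 1


    def build_generate_args(all_requests: List[ToyCompletionRequest]) -> dict:
        prompts, return_logprobs, logprob_start_lens, top_logprobs_nums = [], [], [], []
        for request in all_requests:
            if len(all_requests) > 1 and request.n > 1:
                raise ValueError("Parallel sampling is not supported for completions from files")
            prompts.append(request.prompt)
            return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
            logprob_start_lens.append(-1)  # prompt logprobs are not computed via this API
            top_logprobs_nums.append(request.logprobs if request.logprobs is not None else 0)

        if len(all_requests) == 1:
            # A single request passes bare values instead of per-request lists.
            return {
                "prompt": prompts[0],
                "return_logprob": return_logprobs[0],
                "logprob_start_len": logprob_start_lens[0],
                "top_logprobs_num": top_logprobs_nums[0],
            }
        return {
            "prompt": prompts,
            "return_logprob": return_logprobs,
            "logprob_start_len": logprob_start_lens,
            "top_logprobs_num": top_logprobs_nums,
        }


    print(build_generate_args([ToyCompletionRequest(prompt="hello", logprobs=2)]))
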
@@ -580,27 +600,45 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     if adapted_request.stream:
 
         async def generate_stream_resp():
-            stream_buffer = ""
-            n_prev_token = 0
+            stream_buffers = {}
+            n_prev_tokens = {}
+            prompt_tokens = {}
+            completion_tokens = {}
             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
+                    index = content["index"]
+
+                    stream_buffer = stream_buffers.get(index, "")
+                    n_prev_token = n_prev_tokens.get(index, 0)
+
                     text = content["text"]
-                    prompt_tokens = content["meta_info"]["prompt_tokens"]
-                    completion_tokens = content["meta_info"]["completion_tokens"]
+                    prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
+                    completion_tokens[index] = content["meta_info"]["completion_tokens"]
 
                     if not stream_buffer:  # The first chunk
                         if request.echo:
                             if isinstance(request.prompt, str):
                                 # for the case of single str prompts
                                 prompts = request.prompt
-                            elif isinstance(request.prompt, list) and isinstance(
-                                request.prompt[0], int
-                            ):
-                                prompts = tokenizer_manager.tokenizer.decode(
-                                    request.prompt, skip_special_tokens=True
-                                )
+                            elif isinstance(request.prompt, list):
+                                if isinstance(request.prompt[0], str):
+                                    # for the case of multiple str prompts
+                                    prompts = request.prompt[index // request.n]
+                                elif isinstance(request.prompt[0], int):
+                                    # for the case of single token ids prompt
+                                    prompts = tokenizer_manager.tokenizer.decode(
+                                        request.prompt, skip_special_tokens=True
+                                    )
+                                elif isinstance(request.prompt[0], list) and isinstance(
+                                    request.prompt[0][0], int
+                                ):
+                                    # for the case of multiple token ids prompts
+                                    prompts = tokenizer_manager.tokenizer.decode(
+                                        request.prompt[index // request.n],
+                                        skip_special_tokens=True,
+                                    )
 
                             # Prepend prompt in response text.
                             text = prompts + text
@@ -637,7 +675,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     delta = text[len(stream_buffer) :]
                     stream_buffer = stream_buffer + delta
                     choice_data = CompletionResponseStreamChoice(
-                        index=0,
+                        index=index,
                         text=delta,
                         logprobs=logprobs,
                         finish_reason=format_finish_reason(
@@ -650,12 +688,24 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                         choices=[choice_data],
                         model=request.model,
                     )
+
+                    stream_buffers[index] = stream_buffer
+                    n_prev_tokens[index] = n_prev_token
+
                     yield f"data: {chunk.model_dump_json()}\n\n"
                 if request.stream_options and request.stream_options.include_usage:
+                    total_prompt_tokens = sum(
+                        tokens
+                        for i, tokens in prompt_tokens.items()
+                        if i % request.n == 0
+                    )
+                    total_completion_tokens = sum(
+                        tokens for tokens in completion_tokens.values()
+                    )
                     usage = UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=prompt_tokens + completion_tokens,
+                        prompt_tokens=total_prompt_tokens,
+                        completion_tokens=total_completion_tokens,
+                        total_tokens=total_prompt_tokens + total_completion_tokens,
                     )
 
                     final_usage_chunk = CompletionStreamResponse(
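
With parallel sampling (request.n > 1) the stream now interleaves chunks for several choice indices, so the handler keeps per-index buffers and token counts and, when usage is requested, counts each shared prompt only once (the i % request.n == 0 filter). A hedged sketch of that aggregation over a simplified chunk format; real chunks also carry the text deltas and meta info:

    def aggregate_stream_usage(chunks, n):
        # Accumulate per-index token counts from interleaved stream chunks.
        prompt_tokens = {}
        completion_tokens = {}
        for chunk in chunks:
            index = chunk["index"]
            # Later chunks overwrite earlier ones, so the final counts win.
            prompt_tokens[index] = chunk["prompt_tokens"]
            completion_tokens[index] = chunk["completion_tokens"]

        # The n parallel samples of one prompt share the same prompt tokens,
        # so only every n-th index contributes to the prompt total.
        total_prompt = sum(t for i, t in prompt_tokens.items() if i % n == 0)
        total_completion = sum(completion_tokens.values())
        return {
            "prompt_tokens": total_prompt,
            "completion_tokens": total_completion,
            "total_tokens": total_prompt + total_completion,
        }


    chunks = [
        {"index": 0, "prompt_tokens": 5, "completion_tokens": 3},
        {"index": 1, "prompt_tokens": 5, "completion_tokens": 4},  # second sample, same prompt
    ]
    print(aggregate_stream_usage(chunks, n=2))
    # {'prompt_tokens': 5, 'completion_tokens': 7, 'total_tokens': 12}
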
@@ -694,12 +744,18 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     return response
 
 
-def v1_chat_generate_request(all_requests, tokenizer_manager):
+def v1_chat_generate_request(
+    all_requests: List[ChatCompletionRequest], tokenizer_manager
+):
     input_ids = []
     sampling_params_list = []
     image_data_list = []
     return_logprobs = []
+    logprob_start_lens = []
     top_logprobs_nums = []
+
+    # NOTE: with openai API, the prompt's logprobs are always not computed
+
     for request in all_requests:
         # Prep the data needed for the underlying GenerateReqInput:
         # - prompt: The full prompt string.
@@ -732,6 +788,7 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
             image_data = None
         input_ids.append(prompt_ids)
         return_logprobs.append(request.logprobs)
+        logprob_start_lens.append(-1)
         top_logprobs_nums.append(request.top_logprobs)
         sampling_params_list.append(
             {
@@ -758,17 +815,20 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         sampling_params_list = sampling_params_list[0]
         image_data = image_data_list[0]
         return_logprobs = return_logprobs[0]
+        logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
     else:
         if isinstance(input_ids[0], str):
             prompt_kwargs = {"text": input_ids}
         else:
             prompt_kwargs = {"input_ids": input_ids}
+
     adapted_request = GenerateReqInput(
         **prompt_kwargs,
         image_data=image_data,
         sampling_params=sampling_params_list,
         return_logprob=return_logprobs,
+        logprob_start_len=logprob_start_lens,
         top_logprobs_num=top_logprobs_nums,
         stream=all_requests[0].stream,
         return_text_in_logprobs=True,
@@ -892,16 +952,23 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     if adapted_request.stream:
 
         async def generate_stream_resp():
-            is_first = True
-
-            stream_buffer = ""
-            n_prev_token = 0
+            is_firsts = {}
+            stream_buffers = {}
+            n_prev_tokens = {}
+            prompt_tokens = {}
+            completion_tokens = {}
             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
-                    prompt_tokens = content["meta_info"]["prompt_tokens"]
-                    completion_tokens = content["meta_info"]["completion_tokens"]
+                    index = content["index"]
+
+                    is_first = is_firsts.get(index, True)
+                    stream_buffer = stream_buffers.get(index, "")
+                    n_prev_token = n_prev_tokens.get(index, 0)
+
+                    prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
+                    completion_tokens[index] = content["meta_info"]["completion_tokens"]
                     if request.logprobs:
                         logprobs = to_openai_style_logprobs(
                             output_token_logprobs=content["meta_info"][
@@ -951,7 +1018,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         # First chunk with role
                         is_first = False
                         choice_data = ChatCompletionResponseStreamChoice(
-                            index=0,
+                            index=index,
                             delta=DeltaMessage(role="assistant"),
                             finish_reason=format_finish_reason(
                                 content["meta_info"]["finish_reason"]
@@ -969,7 +1036,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                     delta = text[len(stream_buffer) :]
                     stream_buffer = stream_buffer + delta
                     choice_data = ChatCompletionResponseStreamChoice(
-                        index=0,
+                        index=index,
                         delta=DeltaMessage(content=delta),
                         finish_reason=format_finish_reason(
                             content["meta_info"]["finish_reason"]
@@ -981,12 +1048,25 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         choices=[choice_data],
                         model=request.model,
                     )
+
+                    is_firsts[index] = is_first
+                    stream_buffers[index] = stream_buffer
+                    n_prev_tokens[index] = n_prev_token
+
                     yield f"data: {chunk.model_dump_json()}\n\n"
                 if request.stream_options and request.stream_options.include_usage:
+                    total_prompt_tokens = sum(
+                        tokens
+                        for i, tokens in prompt_tokens.items()
+                        if i % request.n == 0
+                    )
+                    total_completion_tokens = sum(
+                        tokens for tokens in completion_tokens.values()
+                    )
                     usage = UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=prompt_tokens + completion_tokens,
+                        prompt_tokens=total_prompt_tokens,
+                        completion_tokens=total_completion_tokens,
+                        total_tokens=total_prompt_tokens + total_completion_tokens,
                    )
 
                     final_usage_chunk = ChatCompletionStreamResponse(