sglang 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/__init__.py +33 -26
  2. sglang/api.py +9 -1
  3. sglang/bench_latency.py +2 -2
  4. sglang/bench_serving.py +10 -1
  5. sglang/check_env.py +1 -1
  6. sglang/lang/backend/litellm.py +1 -1
  7. sglang/lang/backend/openai.py +1 -1
  8. sglang/lang/backend/runtime_endpoint.py +4 -4
  9. sglang/lang/interpreter.py +24 -9
  10. sglang/lang/ir.py +1 -1
  11. sglang/srt/constrained/__init__.py +15 -0
  12. sglang/srt/constrained/base_cache.py +15 -0
  13. sglang/srt/constrained/fsm_cache.py +36 -1
  14. sglang/srt/constrained/jump_forward.py +15 -0
  15. sglang/srt/conversation.py +26 -0
  16. sglang/srt/hf_transformers_utils.py +18 -1
  17. sglang/srt/layers/context_flashattention_nopad.py +15 -0
  18. sglang/srt/layers/extend_attention.py +15 -0
  19. sglang/srt/layers/fused_moe.py +15 -0
  20. sglang/srt/layers/linear.py +15 -0
  21. sglang/srt/layers/logits_processor.py +109 -72
  22. sglang/srt/layers/quantization/__init__.py +15 -0
  23. sglang/srt/layers/quantization/fp8.py +15 -0
  24. sglang/srt/layers/radix_attention.py +21 -3
  25. sglang/srt/layers/token_attention.py +16 -1
  26. sglang/srt/managers/{controller/manager_multi.py → controller_multi.py} +17 -2
  27. sglang/srt/managers/{controller/manager_single.py → controller_single.py} +17 -2
  28. sglang/srt/managers/detokenizer_manager.py +16 -1
  29. sglang/srt/managers/io_struct.py +38 -5
  30. sglang/srt/managers/{controller/schedule_heuristic.py → policy_scheduler.py} +37 -22
  31. sglang/srt/managers/{controller/infer_batch.py → schedule_batch.py} +85 -25
  32. sglang/srt/managers/tokenizer_manager.py +99 -57
  33. sglang/srt/managers/{controller/tp_worker.py → tp_worker.py} +177 -81
  34. sglang/srt/mem_cache/flush_cache.py +33 -0
  35. sglang/srt/{memory_pool.py → mem_cache/memory_pool.py} +16 -1
  36. sglang/srt/{managers/controller → mem_cache}/radix_cache.py +15 -0
  37. sglang/srt/mm_utils.py +15 -0
  38. sglang/srt/model_config.py +20 -0
  39. sglang/srt/{managers/controller → model_executor}/cuda_graph_runner.py +42 -18
  40. sglang/srt/{managers/controller → model_executor}/model_runner.py +51 -16
  41. sglang/srt/model_loader/model_loader.py +15 -0
  42. sglang/srt/model_loader/utils.py +16 -1
  43. sglang/srt/models/chatglm.py +16 -1
  44. sglang/srt/models/commandr.py +16 -1
  45. sglang/srt/models/dbrx.py +16 -1
  46. sglang/srt/models/deepseek.py +16 -1
  47. sglang/srt/models/deepseek_v2.py +532 -0
  48. sglang/srt/models/gemma.py +16 -1
  49. sglang/srt/models/gemma2.py +16 -1
  50. sglang/srt/models/gpt_bigcode.py +16 -1
  51. sglang/srt/models/grok.py +16 -1
  52. sglang/srt/models/internlm2.py +16 -1
  53. sglang/srt/models/llama2.py +16 -1
  54. sglang/srt/models/llama_classification.py +19 -4
  55. sglang/srt/models/llava.py +17 -2
  56. sglang/srt/models/llavavid.py +17 -2
  57. sglang/srt/models/minicpm.py +16 -1
  58. sglang/srt/models/mistral.py +15 -0
  59. sglang/srt/models/mixtral.py +16 -1
  60. sglang/srt/models/mixtral_quant.py +16 -1
  61. sglang/srt/models/qwen.py +16 -1
  62. sglang/srt/models/qwen2.py +16 -1
  63. sglang/srt/models/qwen2_moe.py +16 -1
  64. sglang/srt/models/stablelm.py +16 -1
  65. sglang/srt/models/yivl.py +15 -0
  66. sglang/srt/openai_api/adapter.py +545 -160
  67. sglang/srt/openai_api/protocol.py +65 -1
  68. sglang/srt/sampling_params.py +20 -4
  69. sglang/srt/server.py +90 -37
  70. sglang/srt/server_args.py +76 -17
  71. sglang/srt/utils.py +15 -0
  72. sglang/test/test_programs.py +5 -1
  73. sglang/utils.py +22 -0
  74. sglang/version.py +1 -1
  75. {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/METADATA +40 -12
  76. sglang-0.2.7.dist-info/RECORD +93 -0
  77. {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/WHEEL +1 -1
  78. sglang/srt/flush_cache.py +0 -18
  79. sglang-0.2.5.dist-info/RECORD +0 -92
  80. {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/LICENSE +0 -0
  81. {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/top_level.txt +0 -0
sglang/srt/managers/{controller/infer_batch.py → schedule_batch.py}
@@ -1,5 +1,21 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  """Meta data for requests and batches"""

+ import logging
  import warnings
  from dataclasses import dataclass
  from enum import IntEnum, auto
@@ -12,11 +28,21 @@ from flashinfer.sampling import top_k_top_p_sampling_from_probs
  from sglang.global_config import global_config
  from sglang.srt.constrained import RegexGuide
  from sglang.srt.constrained.jump_forward import JumpForwardMap
- from sglang.srt.managers.controller.radix_cache import RadixCache
- from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+ from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
+ from sglang.srt.mem_cache.radix_cache import RadixCache

  INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

+ # Put some global args for easy access
+ global_server_args_dict = {
+ "disable_flashinfer": False,
+ "disable_flashinfer_sampling": False,
+ "attention_reduce_in_fp32": False,
+ }
+
+
+ logger = logging.getLogger(__name__)
+

  class ForwardMode(IntEnum):
  # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
@@ -124,10 +150,10 @@ class Req:
  self.logprob_start_len = 0
  self.top_logprobs_num = 0
  self.normalized_prompt_logprob = None
- self.prefill_token_logprobs = None
- self.prefill_top_logprobs = None
- self.decode_token_logprobs = []
- self.decode_top_logprobs = []
+ self.input_token_logprobs = None
+ self.input_top_logprobs = None
+ self.output_token_logprobs = []
+ self.output_top_logprobs = []
  # The tokens is prefilled but need to be considered as decode tokens
  # and should be updated for the decode logprobs
  self.last_update_decode_tokens = 0
@@ -244,8 +270,8 @@ class Req:
  k = k + 1
  else:
  break
- self.decode_token_logprobs = self.decode_token_logprobs[:k]
- self.decode_top_logprobs = self.decode_top_logprobs[:k]
+ self.output_token_logprobs = self.output_token_logprobs[:k]
+ self.output_top_logprobs = self.output_top_logprobs[:k]
  self.logprob_start_len = prompt_tokens + k
  self.last_update_decode_tokens = len(self.output_ids) - k

@@ -357,7 +383,7 @@ class Batch:
  out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)

  if out_cache_loc is None:
- print("Prefill out of memory. This should never happen.")
+ logger.error("Prefill out of memory. This should never happen.")
  self.tree_cache.pretty_print()
  exit()

@@ -376,7 +402,7 @@ class Batch:
  logit_bias = torch.zeros(
  (bs, vocab_size), dtype=torch.float32, device=device
  )
- logit_bias[i] = int_token_logit_bias
+ logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias

  # Set fields
  self.input_ids = torch.tensor(
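The logit-bias change above writes only the first `len(int_token_logit_bias)` entries of each row instead of assigning the whole row, which avoids a size mismatch when the model's (padded) vocabulary is larger than the bias vector built from the tokenizer (the likely motivation for the fix). A minimal sketch of the difference, using made-up toy sizes rather than sglang code:

```python
import torch

# Toy sizes only: assume the model's vocab is padded beyond the tokenizer vocab.
vocab_size = 32008
int_token_logit_bias = torch.zeros(32000, dtype=torch.float32)

logit_bias = torch.zeros((2, vocab_size), dtype=torch.float32)

# Old assignment: the row (32008) and the bias (32000) differ in size -> RuntimeError.
try:
    logit_bias[0] = int_token_logit_bias
except RuntimeError as err:
    print("full-row assignment fails:", err)

# New assignment: only write the first len(int_token_logit_bias) entries.
logit_bias[1][: len(int_token_logit_bias)] = int_token_logit_bias
```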
@@ -591,7 +617,7 @@ class Batch:
  self.out_cache_loc = self.token_to_kv_pool.alloc(bs)

  if self.out_cache_loc is None:
- print("Decode out of memory. This should never happen.")
+ logger.error("Decode out of memory. This should never happen.")
  self.tree_cache.pretty_print()
  exit()

@@ -687,13 +713,21 @@ class Batch:
  # TODO(lmzheng): apply penalty
  probs = torch.softmax(logits, dim=-1)

- max_top_k_round, batch_size = 32, probs.shape[0]
- uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device)
- batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
- probs, uniform_samples, self.top_ks, self.top_ps
- )
+ if not global_server_args_dict["disable_flashinfer_sampling"]:
+ max_top_k_round, batch_size = 32, probs.shape[0]
+ uniform_samples = torch.rand(
+ (max_top_k_round, batch_size), device=probs.device
+ )
+ batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+ probs, uniform_samples, self.top_ks, self.top_ps
+ )
+ else:
+ # Here we provide a slower fallback implementation.
+ batch_next_token_ids, success = top_k_top_p_sampling_from_probs_torch(
+ probs, self.top_ks, self.top_ps
+ )

- if torch.any(~success):
+ if not torch.all(success):
  warnings.warn("Sampling failed, fallback to top_k=1 strategy")
  probs = probs.masked_fill(torch.isnan(probs), 0.0)
  argmax_ids = torch.argmax(probs, dim=-1)
@@ -747,7 +781,7 @@ class InputMetadata:
  flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
  flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
  flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
- use_ragged: bool = False
+ flashinfer_use_ragged: bool = False

  @classmethod
  def create(
@@ -763,10 +797,10 @@ class InputMetadata:
  return_logprob=False,
  skip_flashinfer_init=False,
  ):
- use_ragged = False
+ flashinfer_use_ragged = False
  if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
  if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
- use_ragged = True
+ flashinfer_use_ragged = True
  init_flashinfer_args(
  forward_mode,
  model_runner,
@@ -774,7 +808,7 @@ class InputMetadata:
  seq_lens,
  prefix_lens,
  model_runner.flashinfer_decode_wrapper,
- use_ragged,
+ flashinfer_use_ragged,
  )

  batch_size = len(req_pool_indices)
@@ -829,7 +863,7 @@ class InputMetadata:
  flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
  flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
  flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
- use_ragged=use_ragged,
+ flashinfer_use_ragged=flashinfer_use_ragged,
  )

  if model_runner.server_args.disable_flashinfer:
@@ -850,7 +884,7 @@ def init_flashinfer_args(
  seq_lens,
  prefix_lens,
  flashinfer_decode_wrapper,
- use_ragged=False,
+ flashinfer_use_ragged=False,
  ):
  """Init auxiliary variables for FlashInfer attention backend."""
  num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
@@ -859,7 +893,7 @@ def init_flashinfer_args(
  batch_size = len(req_pool_indices)
  total_num_tokens = int(torch.sum(seq_lens))

- if use_ragged:
+ if flashinfer_use_ragged:
  paged_kernel_lens = prefix_lens
  else:
  paged_kernel_lens = seq_lens
@@ -895,7 +929,7 @@ def init_flashinfer_args(
  qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
  qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

- if use_ragged:
+ if flashinfer_use_ragged:
  model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
  model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
  qo_indptr,
@@ -933,3 +967,29 @@ def init_triton_args(forward_mode, seq_lens, prefix_lens):
  max_extend_len = int(torch.max(extend_seq_lens))

  return max_seq_len, max_extend_len, start_loc, prefix_lens
+
+
+ def top_k_top_p_sampling_from_probs_torch(
+ probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
+ ):
+ """A top-k and top-k sampling implementation with native pytorch operations."""
+ probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
+ probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+ probs_sort[
+ torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
+ >= top_ks.view(-1, 1)
+ ] = 0.0
+ probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
+ try:
+ sampled_index = torch.multinomial(probs_sort, num_samples=1)
+ except RuntimeError:
+ batch_next_token_ids = torch.zeros(
+ (probs_sort.shape[0],), dtype=torch.int64, device=probs.device
+ )
+ success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
+ return batch_next_token_ids, success
+
+ batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
+ success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
+ return batch_next_token_ids, success
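The new `top_k_top_p_sampling_from_probs_torch` helper is the pure-PyTorch fallback selected when `disable_flashinfer_sampling` is set in `global_server_args_dict`. A rough usage sketch on toy data, assuming sglang 0.2.7 and its dependencies (including flashinfer, which this module imports) are installed:

```python
import torch
from sglang.srt.managers.schedule_batch import top_k_top_p_sampling_from_probs_torch

# Toy batch: 4 requests over a 128-token vocabulary, with per-request top-k/top-p.
probs = torch.softmax(torch.randn(4, 128), dim=-1)
top_ks = torch.tensor([1, 20, 50, 128])
top_ps = torch.tensor([1.0, 0.9, 0.8, 0.95])

next_ids, success = top_k_top_p_sampling_from_probs_torch(probs, top_ks, top_ps)
print(next_ids.shape, bool(success.all()))  # torch.Size([4]) True
```

Both backends return the same `(next_token_ids, success)` pair, so the caller's top_k=1 argmax fallback shown earlier covers either path.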
sglang/srt/managers/tokenizer_manager.py
@@ -1,3 +1,18 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  """TokenizerManager is a process that tokenizes the text."""

  import asyncio
@@ -6,7 +21,7 @@ import dataclasses
  import logging
  import multiprocessing as mp
  import os
- from typing import Dict, List
+ from typing import Dict, List, Tuple

  import numpy as np
  import transformers
@@ -69,6 +84,7 @@ class TokenizerManager:
  trust_remote_code=server_args.trust_remote_code,
  model_overide_args=model_overide_args,
  )
+
  if server_args.context_length is not None:
  self.context_len = server_args.context_length
  else:
@@ -133,50 +149,54 @@ class TokenizerManager:
  async for response in self._handle_batch_request(obj, request):
  yield response

- async def _handle_single_request(self, obj, request, index=None, is_prefill=False):
- if is_prefill:
- if isinstance(obj.text, list):
- input_text = obj.text[index]
- rid = obj.rid[index]
- else:
- input_text = obj.text
- rid = obj.rid[0]
- input_ids = self.tokenizer.encode(input_text)
- sampling_params = SamplingParams(**obj.sampling_params[0])
- sampling_params.max_new_tokens = 0
- pixel_values, image_hash, image_size = await self._get_pixel_values(
- obj.image_data[0]
- )
- return_logprob = obj.return_logprob[0]
- logprob_start_len = obj.logprob_start_len[0]
- top_logprobs_num = obj.top_logprobs_num[0]
- else:
- rid = obj.rid if index is None else obj.rid[index]
- input_text = obj.text if index is None else obj.text[index]
+ async def _handle_single_request(
+ self, obj, request, index=None, is_cache_for_prefill=False
+ ):
+ if not is_cache_for_prefill:
+ not_use_index = not (index is not None)
+ rid = obj.rid if not_use_index else obj.rid[index]
+ input_text = obj.text if not_use_index else obj.text[index]
  input_ids = (
  self.tokenizer.encode(input_text)
  if obj.input_ids is None
  else obj.input_ids
  )
- if index is not None and obj.input_ids:
+ if not not_use_index and obj.input_ids:
  input_ids = obj.input_ids[index]

  self._validate_input_length(input_ids)
+
  sampling_params = self._get_sampling_params(
- obj.sampling_params if index is None else obj.sampling_params[index]
+ obj.sampling_params if not_use_index else obj.sampling_params[index]
  )
  pixel_values, image_hash, image_size = await self._get_pixel_values(
- obj.image_data if index is None else obj.image_data[index]
+ obj.image_data if not_use_index else obj.image_data[index]
  )
  return_logprob = (
- obj.return_logprob if index is None else obj.return_logprob[index]
+ obj.return_logprob if not_use_index else obj.return_logprob[index]
  )
  logprob_start_len = (
- obj.logprob_start_len if index is None else obj.logprob_start_len[index]
+ obj.logprob_start_len if not_use_index else obj.logprob_start_len[index]
  )
  top_logprobs_num = (
- obj.top_logprobs_num if index is None else obj.top_logprobs_num[index]
+ obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
  )
+ else:
+ if isinstance(obj.text, list):
+ input_text = obj.text[index]
+ rid = obj.rid[index]
+ else:
+ input_text = obj.text
+ rid = obj.rid[0]
+ input_ids = self.tokenizer.encode(input_text)
+ sampling_params = SamplingParams(**obj.sampling_params[0])
+ sampling_params.max_new_tokens = 0
+ pixel_values, image_hash, image_size = await self._get_pixel_values(
+ obj.image_data[0]
+ )
+ return_logprob = obj.return_logprob[0]
+ logprob_start_len = obj.logprob_start_len[0]
+ top_logprobs_num = obj.top_logprobs_num[0]

  tokenized_obj = TokenizedGenerateReqInput(
  rid,
@@ -196,26 +216,26 @@ class TokenizerManager:
  event = asyncio.Event()
  state = ReqState([], False, event)
  self.rid_to_state[rid] = state
- if is_prefill:
- await self._wait_for_prefill_response(event, state, obj, request, rid)
- yield input_ids
- else:
+ if not is_cache_for_prefill:
  async for response in self._wait_for_response(
  event, state, obj, rid, request
  ):
  yield response
+ else:
+ await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
+ yield input_ids

- async def _handle_batch_request(self, obj, request):
+ async def _handle_batch_request(self, obj: GenerateReqInput, request):
  batch_size = obj.batch_size
- parallel_sample_num = obj.sampling_params[0].get("n", 1)
+ parallel_sample_num = obj.parallel_sample_num

  if parallel_sample_num != 1:
- ## send prefill requests
+ # Send prefill requests to cache the common input
  parallel_sample_num += 1
  input_id_result = [] if obj.input_ids is None else None
  for i in range(batch_size):
  async for input_id in self._handle_single_request(
- obj, request, index=i, is_prefill=True
+ obj, request, index=i, is_cache_for_prefill=True
  ):
  if input_id_result is not None:
  input_id_result.append(input_id)
@@ -231,7 +251,7 @@ class TokenizerManager:
  continue
  index = i * parallel_sample_num + j
  if parallel_sample_num != 1:
- # Here when using parallel sampling we shoul consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
+ # Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
  index += batch_size - 1 - i
  rid = obj.rid[index]
  if parallel_sample_num == 1:
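The index arithmetic in the hunk above maps request `i`, sample `j` onto the flat `rid` list while skipping the cache-only prefill requests submitted first. A quick worked check of the formula stated in the comment, with toy numbers and the assumption (suggested by the `continue` context line) that `j == 0` is skipped:

```python
# Toy check: batch_size = 2 requests, n = 2 parallel samples, so the code bumps
# parallel_sample_num to 3 and (assumed here) skips j == 0 as the prefill placeholder.
batch_size, parallel_sample_num = 2, 3

for i in range(batch_size):
    for j in range(parallel_sample_num):
        if j == 0:
            continue
        index = i * parallel_sample_num + j
        index += batch_size - 1 - i
        # Matches the closed form quoted in the comment above.
        assert index == j + i * (parallel_sample_num - 1) + batch_size - 1
        print(i, j, index)  # prints: 0 1 2, 0 2 3, 1 1 4, 1 2 5
```

The resulting indices 2..5 sit right after the first `batch_size` rids, which appear to be reserved for the prefill-only requests.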
@@ -308,17 +328,15 @@

  yield output_list

- def _validate_input_length(self, input_ids):
+ def _validate_input_length(self, input_ids: List[int]):
  if len(input_ids) >= self.context_len:
  raise ValueError(
  f"The input ({len(input_ids)} tokens) is longer than the "
  f"model's context length ({self.context_len} tokens)."
  )

- def _get_sampling_params(self, sampling_params_data, max_new_tokens=None):
+ def _get_sampling_params(self, sampling_params_data: dict):
  sampling_params = SamplingParams(**sampling_params_data)
- if max_new_tokens is not None:
- sampling_params.max_new_tokens = max_new_tokens
  if sampling_params.max_new_tokens != 0:
  sampling_params.normalize(self.tokenizer)
  sampling_params.verify()
@@ -332,7 +350,14 @@ class TokenizerManager:
  else:
  return None, None, None

- async def _wait_for_response(self, event, state, obj, rid, request):
+ async def _wait_for_response(
+ self,
+ event: asyncio.Event,
+ state: ReqState,
+ obj: GenerateReqInput,
+ rid: str,
+ request,
+ ):
  while True:
  try:
  await asyncio.wait_for(event.wait(), timeout=4)
@@ -361,7 +386,14 @@ class TokenizerManager:
  event.clear()
  yield out

- async def _wait_for_prefill_response(self, event, state, obj, request, rid):
+ async def _wait_for_cache_prefill_response(
+ self,
+ event: asyncio.Event,
+ state: ReqState,
+ obj: GenerateReqInput,
+ rid: str,
+ request,
+ ):
  while True:
  try:
  await asyncio.wait_for(state.event.wait(), timeout=4)
@@ -380,7 +412,7 @@ class TokenizerManager:
  req = FlushCacheReq()
  self.send_to_router.send_pyobj(req)

- def abort_request(self, rid):
+ def abort_request(self, rid: str):
  if rid not in self.rid_to_state:
  return
  del self.rid_to_state[rid]
@@ -426,31 +458,37 @@ class TokenizerManager:
  state.event.set()

  def convert_logprob_style(
- self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
+ self,
+ ret: dict,
+ return_logprob: bool,
+ top_logprobs_num: int,
+ return_text_in_logprobs: bool,
  ):
  if return_logprob:
- ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
- ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
+ ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
+ ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
  )
- ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
- ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
+ ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
+ ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
  )

  if top_logprobs_num > 0:
- ret["meta_info"]["prefill_top_logprobs"] = (
+ ret["meta_info"]["input_top_logprobs"] = (
  self.detokenize_top_logprobs_tokens(
- ret["meta_info"]["prefill_top_logprobs"],
+ ret["meta_info"]["input_top_logprobs"],
  return_text_in_logprobs,
  )
  )
- ret["meta_info"]["decode_top_logprobs"] = (
+ ret["meta_info"]["output_top_logprobs"] = (
  self.detokenize_top_logprobs_tokens(
- ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+ ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
  )
  )
  return ret

- def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
+ def detokenize_logprob_tokens(
+ self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
+ ):
  if not decode_to_text:
  return [(logprob, token_id, None) for logprob, token_id in token_logprobs]

@@ -461,10 +499,14 @@ class TokenizerManager:
  for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
  ]

- def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text):
- for i, t in enumerate(top_logprobs):
- if t:
- top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
+ def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
+ # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
+ # We should batch all top-k tokens in all positions.
+ for i, token_top_logprobs in enumerate(top_logprobs):
+ if token_top_logprobs:
+ top_logprobs[i] = self.detokenize_logprob_tokens(
+ token_top_logprobs, decode_to_text
+ )
  return top_logprobs

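Together, the `convert_logprob_style` and detokenization hunks above rename the logprob keys in `meta_info` from `prefill_*`/`decode_*` to `input_*`/`output_*`, so clients reading them from a generate response must switch key names when upgrading from 0.2.5 to 0.2.7. A hedged client-side sketch that tolerates both spellings (the helper is illustrative, not part of the sglang API):

```python
# Old (0.2.5) -> new (0.2.7) key names inside ret["meta_info"].
OLD_TO_NEW = {
    "prefill_token_logprobs": "input_token_logprobs",
    "decode_token_logprobs": "output_token_logprobs",
    "prefill_top_logprobs": "input_top_logprobs",
    "decode_top_logprobs": "output_top_logprobs",
}


def read_logprobs(meta_info: dict) -> dict:
    """Return logprobs under the new key names, falling back to the old ones."""
    return {new: meta_info.get(new, meta_info.get(old)) for old, new in OLD_TO_NEW.items()}
```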