sglang 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +33 -26
- sglang/api.py +9 -1
- sglang/bench_latency.py +2 -2
- sglang/bench_serving.py +10 -1
- sglang/check_env.py +1 -1
- sglang/lang/backend/litellm.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +4 -4
- sglang/lang/interpreter.py +24 -9
- sglang/lang/ir.py +1 -1
- sglang/srt/constrained/__init__.py +15 -0
- sglang/srt/constrained/base_cache.py +15 -0
- sglang/srt/constrained/fsm_cache.py +36 -1
- sglang/srt/constrained/jump_forward.py +15 -0
- sglang/srt/conversation.py +26 -0
- sglang/srt/hf_transformers_utils.py +18 -1
- sglang/srt/layers/context_flashattention_nopad.py +15 -0
- sglang/srt/layers/extend_attention.py +15 -0
- sglang/srt/layers/fused_moe.py +15 -0
- sglang/srt/layers/linear.py +15 -0
- sglang/srt/layers/logits_processor.py +109 -72
- sglang/srt/layers/quantization/__init__.py +15 -0
- sglang/srt/layers/quantization/fp8.py +15 -0
- sglang/srt/layers/radix_attention.py +21 -3
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/{controller/manager_multi.py → controller_multi.py} +17 -2
- sglang/srt/managers/{controller/manager_single.py → controller_single.py} +17 -2
- sglang/srt/managers/detokenizer_manager.py +16 -1
- sglang/srt/managers/io_struct.py +38 -5
- sglang/srt/managers/{controller/schedule_heuristic.py → policy_scheduler.py} +37 -22
- sglang/srt/managers/{controller/infer_batch.py → schedule_batch.py} +85 -25
- sglang/srt/managers/tokenizer_manager.py +99 -57
- sglang/srt/managers/{controller/tp_worker.py → tp_worker.py} +177 -81
- sglang/srt/mem_cache/flush_cache.py +33 -0
- sglang/srt/{memory_pool.py → mem_cache/memory_pool.py} +16 -1
- sglang/srt/{managers/controller → mem_cache}/radix_cache.py +15 -0
- sglang/srt/mm_utils.py +15 -0
- sglang/srt/model_config.py +20 -0
- sglang/srt/{managers/controller → model_executor}/cuda_graph_runner.py +42 -18
- sglang/srt/{managers/controller → model_executor}/model_runner.py +51 -16
- sglang/srt/model_loader/model_loader.py +15 -0
- sglang/srt/model_loader/utils.py +16 -1
- sglang/srt/models/chatglm.py +16 -1
- sglang/srt/models/commandr.py +16 -1
- sglang/srt/models/dbrx.py +16 -1
- sglang/srt/models/deepseek.py +16 -1
- sglang/srt/models/deepseek_v2.py +532 -0
- sglang/srt/models/gemma.py +16 -1
- sglang/srt/models/gemma2.py +16 -1
- sglang/srt/models/gpt_bigcode.py +16 -1
- sglang/srt/models/grok.py +16 -1
- sglang/srt/models/internlm2.py +16 -1
- sglang/srt/models/llama2.py +16 -1
- sglang/srt/models/llama_classification.py +19 -4
- sglang/srt/models/llava.py +17 -2
- sglang/srt/models/llavavid.py +17 -2
- sglang/srt/models/minicpm.py +16 -1
- sglang/srt/models/mistral.py +15 -0
- sglang/srt/models/mixtral.py +16 -1
- sglang/srt/models/mixtral_quant.py +16 -1
- sglang/srt/models/qwen.py +16 -1
- sglang/srt/models/qwen2.py +16 -1
- sglang/srt/models/qwen2_moe.py +16 -1
- sglang/srt/models/stablelm.py +16 -1
- sglang/srt/models/yivl.py +15 -0
- sglang/srt/openai_api/adapter.py +545 -160
- sglang/srt/openai_api/protocol.py +65 -1
- sglang/srt/sampling_params.py +20 -4
- sglang/srt/server.py +90 -37
- sglang/srt/server_args.py +76 -17
- sglang/srt/utils.py +15 -0
- sglang/test/test_programs.py +5 -1
- sglang/utils.py +22 -0
- sglang/version.py +1 -1
- {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/METADATA +40 -12
- sglang-0.2.7.dist-info/RECORD +93 -0
- {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/WHEEL +1 -1
- sglang/srt/flush_cache.py +0 -18
- sglang-0.2.5.dist-info/RECORD +0 -92
- {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/LICENSE +0 -0
- {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/top_level.txt +0 -0
sglang/srt/managers/{controller/infer_batch.py → schedule_batch.py}

@@ -1,5 +1,21 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """Meta data for requests and batches"""
 
+import logging
 import warnings
 from dataclasses import dataclass
 from enum import IntEnum, auto
@@ -12,11 +28,21 @@ from flashinfer.sampling import top_k_top_p_sampling_from_probs
 from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
-from sglang.srt.
-from sglang.srt.
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.radix_cache import RadixCache
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
+# Put some global args for easy access
+global_server_args_dict = {
+    "disable_flashinfer": False,
+    "disable_flashinfer_sampling": False,
+    "attention_reduce_in_fp32": False,
+}
+
+
+logger = logging.getLogger(__name__)
+
 
 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
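
The hunk above adds a module-level `global_server_args_dict` so that batch and sampling code can check a few server flags without threading them through every call (the sampling hunk further down reads `disable_flashinfer_sampling` from it). Below is a minimal sketch of that pattern; the argparse wiring is illustrative only and is not sglang's actual `ServerArgs` plumbing — only the three flag names are taken from the diff.

```python
# Illustrative sketch of the module-level flag dict pattern; not sglang's
# actual ServerArgs code. Only the flag names come from the diff above.
import argparse

global_server_args_dict = {
    "disable_flashinfer": False,
    "disable_flashinfer_sampling": False,
    "attention_reduce_in_fp32": False,
}


def apply_server_flags(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--disable-flashinfer", action="store_true")
    parser.add_argument("--disable-flashinfer-sampling", action="store_true")
    parser.add_argument("--attention-reduce-in-fp32", action="store_true")
    args = parser.parse_args(argv)
    # Copy the parsed flags into the shared dict so later code (e.g. the
    # sampling path) can branch on them without extra parameters.
    global_server_args_dict["disable_flashinfer"] = args.disable_flashinfer
    global_server_args_dict["disable_flashinfer_sampling"] = args.disable_flashinfer_sampling
    global_server_args_dict["attention_reduce_in_fp32"] = args.attention_reduce_in_fp32


apply_server_flags(["--disable-flashinfer-sampling"])
assert global_server_args_dict["disable_flashinfer_sampling"] is True
```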
@@ -124,10 +150,10 @@ class Req:
         self.logprob_start_len = 0
         self.top_logprobs_num = 0
         self.normalized_prompt_logprob = None
-        self.
-        self.
-        self.
-        self.
+        self.input_token_logprobs = None
+        self.input_top_logprobs = None
+        self.output_token_logprobs = []
+        self.output_top_logprobs = []
         # The tokens is prefilled but need to be considered as decode tokens
         # and should be updated for the decode logprobs
         self.last_update_decode_tokens = 0
@@ -244,8 +270,8 @@ class Req:
                 k = k + 1
             else:
                 break
-        self.
-        self.
+        self.output_token_logprobs = self.output_token_logprobs[:k]
+        self.output_top_logprobs = self.output_top_logprobs[:k]
         self.logprob_start_len = prompt_tokens + k
         self.last_update_decode_tokens = len(self.output_ids) - k
 
@@ -357,7 +383,7 @@ class Batch:
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
 
         if out_cache_loc is None:
-
+            logger.error("Prefill out of memory. This should never happen.")
             self.tree_cache.pretty_print()
             exit()
 
@@ -376,7 +402,7 @@ class Batch:
                 logit_bias = torch.zeros(
                     (bs, vocab_size), dtype=torch.float32, device=device
                 )
-                logit_bias[i] = int_token_logit_bias
+                logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
 
         # Set fields
         self.input_ids = torch.tensor(
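
The one-line change above writes the bias vector only into the first `len(int_token_logit_bias)` columns of the row instead of assigning across the whole row, so the assignment still works when the logits are allocated with a vocab size larger than the bias vector (e.g. a padded vocabulary; the exact reason is not stated in the diff). A small self-contained illustration with made-up sizes:

```python
# Made-up sizes to illustrate why the sliced assignment is needed.
import torch

vocab_size = 32004                        # hypothetical padded vocab width
bias = torch.full((32000,), -100.0)       # bias covers only 32000 token ids

logit_bias = torch.zeros((2, vocab_size), dtype=torch.float32)

# logit_bias[0] = bias                    # old form: RuntimeError (32000 != 32004)
logit_bias[0][: len(bias)] = bias         # new form: fill the covered prefix only

print(logit_bias[0, 31999].item(), logit_bias[0, 32000].item())  # -100.0 0.0
```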
@@ -591,7 +617,7 @@ class Batch:
         self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
 
         if self.out_cache_loc is None:
-
+            logger.error("Decode out of memory. This should never happen.")
             self.tree_cache.pretty_print()
             exit()
 
@@ -687,13 +713,21 @@ class Batch:
         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)
 
-
-
-
-
-
+        if not global_server_args_dict["disable_flashinfer_sampling"]:
+            max_top_k_round, batch_size = 32, probs.shape[0]
+            uniform_samples = torch.rand(
+                (max_top_k_round, batch_size), device=probs.device
+            )
+            batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                probs, uniform_samples, self.top_ks, self.top_ps
+            )
+        else:
+            # Here we provide a slower fallback implementation.
+            batch_next_token_ids, success = top_k_top_p_sampling_from_probs_torch(
+                probs, self.top_ks, self.top_ps
+            )
 
-        if torch.
+        if not torch.all(success):
             warnings.warn("Sampling failed, fallback to top_k=1 strategy")
             probs = probs.masked_fill(torch.isnan(probs), 0.0)
             argmax_ids = torch.argmax(probs, dim=-1)
@@ -747,7 +781,7 @@ class InputMetadata:
     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
-
+    flashinfer_use_ragged: bool = False
 
     @classmethod
     def create(
@@ -763,10 +797,10 @@ class InputMetadata:
         return_logprob=False,
         skip_flashinfer_init=False,
     ):
-
+        flashinfer_use_ragged = False
         if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
             if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
-
+                flashinfer_use_ragged = True
             init_flashinfer_args(
                 forward_mode,
                 model_runner,
@@ -774,7 +808,7 @@ class InputMetadata:
                 seq_lens,
                 prefix_lens,
                 model_runner.flashinfer_decode_wrapper,
-
+                flashinfer_use_ragged,
             )
 
         batch_size = len(req_pool_indices)
@@ -829,7 +863,7 @@ class InputMetadata:
             flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
             flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
             flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
-
+            flashinfer_use_ragged=flashinfer_use_ragged,
         )
 
         if model_runner.server_args.disable_flashinfer:
@@ -850,7 +884,7 @@ def init_flashinfer_args(
     seq_lens,
     prefix_lens,
     flashinfer_decode_wrapper,
-
+    flashinfer_use_ragged=False,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
@@ -859,7 +893,7 @@ def init_flashinfer_args(
     batch_size = len(req_pool_indices)
     total_num_tokens = int(torch.sum(seq_lens))
 
-    if
+    if flashinfer_use_ragged:
         paged_kernel_lens = prefix_lens
     else:
         paged_kernel_lens = seq_lens
@@ -895,7 +929,7 @@
     qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
     qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
 
-    if
+    if flashinfer_use_ragged:
         model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
         model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
             qo_indptr,
@@ -933,3 +967,29 @@ def init_triton_args(forward_mode, seq_lens, prefix_lens):
     max_extend_len = int(torch.max(extend_seq_lens))
 
     return max_seq_len, max_extend_len, start_loc, prefix_lens
+
+
+def top_k_top_p_sampling_from_probs_torch(
+    probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
+):
+    """A top-k and top-k sampling implementation with native pytorch operations."""
+    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+    probs_sort[
+        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
+        >= top_ks.view(-1, 1)
+    ] = 0.0
+    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
+    try:
+        sampled_index = torch.multinomial(probs_sort, num_samples=1)
+    except RuntimeError:
+        batch_next_token_ids = torch.zeros(
+            (probs_sort.shape[0],), dtype=torch.int64, device=probs.device
+        )
+        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
+        return batch_next_token_ids, success
+
+    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
+    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
+    return batch_next_token_ids, success
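
The new `top_k_top_p_sampling_from_probs_torch` above is the pure-PyTorch fallback used when FlashInfer's fused `top_k_top_p_sampling_from_probs` kernel is disabled. The sketch below walks through the same filtering idea on a single CPU distribution so it can run without sglang or a GPU; it mirrors the masking steps but is not the function from the diff (for example, it takes scalar `top_k`/`top_p` instead of per-request tensors).

```python
import torch


def top_k_top_p_sample(probs: torch.Tensor, top_k: int, top_p: float) -> int:
    """CPU sketch of top-k / top-p filtering over one probability vector."""
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    cumsum = torch.cumsum(probs_sort, dim=-1)
    # Zero out tokens whose preceding cumulative mass already exceeds top_p.
    probs_sort[(cumsum - probs_sort) > top_p] = 0.0
    # Zero out everything past the top_k highest-probability tokens.
    probs_sort[top_k:] = 0.0
    # torch.multinomial accepts unnormalized non-negative weights.
    sampled = torch.multinomial(probs_sort, num_samples=1)
    return probs_idx[sampled].item()


probs = torch.softmax(torch.tensor([2.0, 1.0, 0.5, -1.0]), dim=-1)
print(top_k_top_p_sample(probs, top_k=3, top_p=0.9))
```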
sglang/srt/managers/tokenizer_manager.py

@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """TokenizerManager is a process that tokenizes the text."""
 
 import asyncio
@@ -6,7 +21,7 @@ import dataclasses
 import logging
 import multiprocessing as mp
 import os
-from typing import Dict, List
+from typing import Dict, List, Tuple
 
 import numpy as np
 import transformers
@@ -69,6 +84,7 @@ class TokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
             model_overide_args=model_overide_args,
         )
+
         if server_args.context_length is not None:
             self.context_len = server_args.context_length
         else:
@@ -133,50 +149,54 @@ class TokenizerManager:
             async for response in self._handle_batch_request(obj, request):
                 yield response
 
-    async def _handle_single_request(
-
-
-
-
-        else
-
-            rid = obj.rid[0]
-            input_ids = self.tokenizer.encode(input_text)
-            sampling_params = SamplingParams(**obj.sampling_params[0])
-            sampling_params.max_new_tokens = 0
-            pixel_values, image_hash, image_size = await self._get_pixel_values(
-                obj.image_data[0]
-            )
-            return_logprob = obj.return_logprob[0]
-            logprob_start_len = obj.logprob_start_len[0]
-            top_logprobs_num = obj.top_logprobs_num[0]
-        else:
-            rid = obj.rid if index is None else obj.rid[index]
-            input_text = obj.text if index is None else obj.text[index]
+    async def _handle_single_request(
+        self, obj, request, index=None, is_cache_for_prefill=False
+    ):
+        if not is_cache_for_prefill:
+            not_use_index = not (index is not None)
+            rid = obj.rid if not_use_index else obj.rid[index]
+            input_text = obj.text if not_use_index else obj.text[index]
             input_ids = (
                 self.tokenizer.encode(input_text)
                 if obj.input_ids is None
                 else obj.input_ids
             )
-            if
+            if not not_use_index and obj.input_ids:
                 input_ids = obj.input_ids[index]
 
             self._validate_input_length(input_ids)
+
             sampling_params = self._get_sampling_params(
-                obj.sampling_params if
+                obj.sampling_params if not_use_index else obj.sampling_params[index]
             )
             pixel_values, image_hash, image_size = await self._get_pixel_values(
-                obj.image_data if
+                obj.image_data if not_use_index else obj.image_data[index]
            )
             return_logprob = (
-                obj.return_logprob if
+                obj.return_logprob if not_use_index else obj.return_logprob[index]
             )
             logprob_start_len = (
-                obj.logprob_start_len if
+                obj.logprob_start_len if not_use_index else obj.logprob_start_len[index]
             )
             top_logprobs_num = (
-                obj.top_logprobs_num if
+                obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
             )
+        else:
+            if isinstance(obj.text, list):
+                input_text = obj.text[index]
+                rid = obj.rid[index]
+            else:
+                input_text = obj.text
+                rid = obj.rid[0]
+            input_ids = self.tokenizer.encode(input_text)
+            sampling_params = SamplingParams(**obj.sampling_params[0])
+            sampling_params.max_new_tokens = 0
+            pixel_values, image_hash, image_size = await self._get_pixel_values(
+                obj.image_data[0]
+            )
+            return_logprob = obj.return_logprob[0]
+            logprob_start_len = obj.logprob_start_len[0]
+            top_logprobs_num = obj.top_logprobs_num[0]
 
         tokenized_obj = TokenizedGenerateReqInput(
             rid,
@@ -196,26 +216,26 @@ class TokenizerManager:
         event = asyncio.Event()
         state = ReqState([], False, event)
         self.rid_to_state[rid] = state
-        if
-            await self._wait_for_prefill_response(event, state, obj, request, rid)
-            yield input_ids
-        else:
+        if not is_cache_for_prefill:
             async for response in self._wait_for_response(
                 event, state, obj, rid, request
             ):
                 yield response
+        else:
+            await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
+            yield input_ids
 
-    async def _handle_batch_request(self, obj, request):
+    async def _handle_batch_request(self, obj: GenerateReqInput, request):
         batch_size = obj.batch_size
-        parallel_sample_num = obj.
+        parallel_sample_num = obj.parallel_sample_num
 
         if parallel_sample_num != 1:
-
+            # Send prefill requests to cache the common input
             parallel_sample_num += 1
             input_id_result = [] if obj.input_ids is None else None
             for i in range(batch_size):
                 async for input_id in self._handle_single_request(
-                    obj, request, index=i,
+                    obj, request, index=i, is_cache_for_prefill=True
                 ):
                     if input_id_result is not None:
                         input_id_result.append(input_id)
@@ -231,7 +251,7 @@ class TokenizerManager:
                     continue
                 index = i * parallel_sample_num + j
                 if parallel_sample_num != 1:
-                    # Here when using parallel sampling we
+                    # Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
                     index += batch_size - 1 - i
                 rid = obj.rid[index]
                 if parallel_sample_num == 1:
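
The expanded comment above spells out the request index used once parallel sampling adds one extra prefill-cache request per batch item (`parallel_sample_num += 1` in the earlier hunk). A quick check with illustrative numbers that the code's two-step computation matches the formula in the comment:

```python
# Illustrative values only; checks that the two forms of the index agree.
batch_size = 4
parallel_sample_num = 3        # value after the "+= 1" for the cached prefill request
i, j = 2, 1                    # batch item i, sample j

index = i * parallel_sample_num + j
index += batch_size - 1 - i    # what the code does
assert index == j + i * (parallel_sample_num - 1) + batch_size - 1  # the comment's form
print(index)                   # 8
```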
@@ -308,17 +328,15 @@ class TokenizerManager:
 
         yield output_list
 
-    def _validate_input_length(self, input_ids):
+    def _validate_input_length(self, input_ids: List[int]):
         if len(input_ids) >= self.context_len:
             raise ValueError(
                 f"The input ({len(input_ids)} tokens) is longer than the "
                 f"model's context length ({self.context_len} tokens)."
             )
 
-    def _get_sampling_params(self, sampling_params_data
+    def _get_sampling_params(self, sampling_params_data: dict):
         sampling_params = SamplingParams(**sampling_params_data)
-        if max_new_tokens is not None:
-            sampling_params.max_new_tokens = max_new_tokens
         if sampling_params.max_new_tokens != 0:
             sampling_params.normalize(self.tokenizer)
             sampling_params.verify()
@@ -332,7 +350,14 @@ class TokenizerManager:
         else:
             return None, None, None
 
-    async def _wait_for_response(
+    async def _wait_for_response(
+        self,
+        event: asyncio.Event,
+        state: ReqState,
+        obj: GenerateReqInput,
+        rid: str,
+        request,
+    ):
         while True:
             try:
                 await asyncio.wait_for(event.wait(), timeout=4)
@@ -361,7 +386,14 @@ class TokenizerManager:
             event.clear()
             yield out
 
-    async def
+    async def _wait_for_cache_prefill_response(
+        self,
+        event: asyncio.Event,
+        state: ReqState,
+        obj: GenerateReqInput,
+        rid: str,
+        request,
+    ):
         while True:
             try:
                 await asyncio.wait_for(state.event.wait(), timeout=4)
@@ -380,7 +412,7 @@ class TokenizerManager:
         req = FlushCacheReq()
         self.send_to_router.send_pyobj(req)
 
-    def abort_request(self, rid):
+    def abort_request(self, rid: str):
         if rid not in self.rid_to_state:
             return
         del self.rid_to_state[rid]
@@ -426,31 +458,37 @@ class TokenizerManager:
         state.event.set()
 
     def convert_logprob_style(
-        self,
+        self,
+        ret: dict,
+        return_logprob: bool,
+        top_logprobs_num: int,
+        return_text_in_logprobs: bool,
     ):
         if return_logprob:
-            ret["meta_info"]["
-            ret["meta_info"]["
+            ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
+                ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
             )
-            ret["meta_info"]["
-            ret["meta_info"]["
+            ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
+                ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
             )
 
             if top_logprobs_num > 0:
-                ret["meta_info"]["
+                ret["meta_info"]["input_top_logprobs"] = (
                     self.detokenize_top_logprobs_tokens(
-                        ret["meta_info"]["
+                        ret["meta_info"]["input_top_logprobs"],
                         return_text_in_logprobs,
                     )
                 )
-                ret["meta_info"]["
+                ret["meta_info"]["output_top_logprobs"] = (
                     self.detokenize_top_logprobs_tokens(
-                        ret["meta_info"]["
+                        ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
                     )
                 )
         return ret
 
-    def detokenize_logprob_tokens(
+    def detokenize_logprob_tokens(
+        self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
+    ):
         if not decode_to_text:
             return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
 
@@ -461,10 +499,14 @@ class TokenizerManager:
             for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
         ]
 
-    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text):
-        for
-
-
+    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
+        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
+        # We should batch all top-k tokens in all positions.
+        for i, token_top_logprobs in enumerate(top_logprobs):
+            if token_top_logprobs:
+                top_logprobs[i] = self.detokenize_logprob_tokens(
+                    token_top_logprobs, decode_to_text
+                )
        return top_logprobs
 
 
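
After this change, detokenized per-token logprobs in a response's `meta_info` are keyed `input_token_logprobs` / `output_token_logprobs` (and `input_top_logprobs` / `output_top_logprobs` for the top-k variants), each entry being a `(logprob, token_id, text)` triple where `text` is `None` when detokenization is skipped. A hedged client-side sketch of reading such a payload; only the key names and the triple layout come from the diff, the surrounding values are invented.

```python
# Hypothetical response fragment; token ids, texts, and logprob values are made up.
ret = {
    "meta_info": {
        "input_token_logprobs": [(-1.25, 3923, "What"), (-0.40, 374, " is")],
        "output_token_logprobs": [(-0.05, 279, " the"), (-2.10, 4320, " answer")],
    }
}

for logprob, token_id, text in ret["meta_info"]["output_token_logprobs"]:
    print(f"token_id={token_id:<6} text={text!r:<10} logprob={logprob:.2f}")
```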