sglang 0.2.6__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +33 -26
- sglang/api.py +9 -1
- sglang/bench_latency.py +2 -2
- sglang/bench_serving.py +10 -1
- sglang/check_env.py +1 -1
- sglang/lang/backend/litellm.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/interpreter.py +20 -5
- sglang/lang/ir.py +1 -1
- sglang/srt/constrained/__init__.py +15 -0
- sglang/srt/constrained/base_cache.py +15 -0
- sglang/srt/constrained/fsm_cache.py +15 -0
- sglang/srt/constrained/jump_forward.py +15 -0
- sglang/srt/conversation.py +26 -0
- sglang/srt/hf_transformers_utils.py +15 -0
- sglang/srt/layers/context_flashattention_nopad.py +15 -0
- sglang/srt/layers/extend_attention.py +15 -0
- sglang/srt/layers/fused_moe.py +15 -0
- sglang/srt/layers/linear.py +15 -0
- sglang/srt/layers/logits_processor.py +41 -13
- sglang/srt/layers/quantization/__init__.py +15 -0
- sglang/srt/layers/quantization/fp8.py +15 -0
- sglang/srt/layers/radix_attention.py +17 -2
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/{controller/manager_multi.py → controller_multi.py} +17 -2
- sglang/srt/managers/{controller/manager_single.py → controller_single.py} +17 -2
- sglang/srt/managers/detokenizer_manager.py +16 -1
- sglang/srt/managers/io_struct.py +36 -3
- sglang/srt/managers/{controller/schedule_heuristic.py → policy_scheduler.py} +37 -22
- sglang/srt/managers/{controller/infer_batch.py → schedule_batch.py} +31 -12
- sglang/srt/managers/tokenizer_manager.py +39 -16
- sglang/srt/managers/{controller/tp_worker.py → tp_worker.py} +130 -40
- sglang/srt/mem_cache/flush_cache.py +33 -0
- sglang/srt/{memory_pool.py → mem_cache/memory_pool.py} +16 -1
- sglang/srt/{managers/controller → mem_cache}/radix_cache.py +15 -0
- sglang/srt/mm_utils.py +15 -0
- sglang/srt/model_config.py +15 -0
- sglang/srt/{managers/controller → model_executor}/cuda_graph_runner.py +16 -1
- sglang/srt/{managers/controller → model_executor}/model_runner.py +32 -12
- sglang/srt/model_loader/model_loader.py +15 -0
- sglang/srt/model_loader/utils.py +16 -1
- sglang/srt/models/chatglm.py +16 -1
- sglang/srt/models/commandr.py +16 -1
- sglang/srt/models/dbrx.py +16 -1
- sglang/srt/models/deepseek.py +16 -1
- sglang/srt/models/deepseek_v2.py +16 -1
- sglang/srt/models/gemma.py +16 -1
- sglang/srt/models/gemma2.py +16 -1
- sglang/srt/models/gpt_bigcode.py +16 -1
- sglang/srt/models/grok.py +16 -1
- sglang/srt/models/internlm2.py +16 -1
- sglang/srt/models/llama2.py +16 -1
- sglang/srt/models/llama_classification.py +16 -1
- sglang/srt/models/llava.py +17 -2
- sglang/srt/models/llavavid.py +17 -2
- sglang/srt/models/minicpm.py +16 -1
- sglang/srt/models/mistral.py +15 -0
- sglang/srt/models/mixtral.py +16 -1
- sglang/srt/models/mixtral_quant.py +16 -1
- sglang/srt/models/qwen.py +16 -1
- sglang/srt/models/qwen2.py +16 -1
- sglang/srt/models/qwen2_moe.py +16 -1
- sglang/srt/models/stablelm.py +16 -1
- sglang/srt/models/yivl.py +15 -0
- sglang/srt/openai_api/adapter.py +520 -135
- sglang/srt/openai_api/protocol.py +64 -0
- sglang/srt/sampling_params.py +15 -0
- sglang/srt/server.py +89 -23
- sglang/srt/server_args.py +49 -11
- sglang/srt/utils.py +15 -0
- sglang/utils.py +22 -0
- sglang/version.py +1 -1
- {sglang-0.2.6.dist-info → sglang-0.2.7.dist-info}/METADATA +32 -6
- sglang-0.2.7.dist-info/RECORD +93 -0
- {sglang-0.2.6.dist-info → sglang-0.2.7.dist-info}/WHEEL +1 -1
- sglang/srt/flush_cache.py +0 -18
- sglang-0.2.6.dist-info/RECORD +0 -93
- {sglang-0.2.6.dist-info → sglang-0.2.7.dist-info}/LICENSE +0 -0
- {sglang-0.2.6.dist-info → sglang-0.2.7.dist-info}/top_level.txt +0 -0
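Most of the per-file churn above is mechanical: every srt module gains the same 15-line Apache-2.0 license header, and the sglang/srt/managers/controller/ package is flattened into sglang/srt/managers/ while cache and model-execution code moves into the new mem_cache/ and model_executor/ packages. Downstream code that imported the moved modules directly has to switch paths; a small sketch of the migration (targets taken from the rename list above, shown only as an illustration):

# sglang 0.2.6 -> 0.2.7 import migration (illustrative):
# from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool          # old
# from sglang.srt.managers.controller.radix_cache import RadixCache         # old
# from sglang.srt.managers.controller.infer_batch import Batch              # old
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool  # new
from sglang.srt.mem_cache.radix_cache import RadixCache                     # new
from sglang.srt.managers.schedule_batch import Batch                        # new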
sglang/srt/managers/{controller/manager_multi.py → controller_multi.py}
RENAMED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """
 A controller that manages multiple data parallel workers.
 Each data parallel worker can manage multiple tensor parallel workers.
@@ -12,7 +27,7 @@ from enum import Enum, auto
 import numpy as np
 import zmq
 
-from sglang.srt.managers.controller.manager_single import (
+from sglang.srt.managers.controller_single import (
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.io_struct import (
@@ -24,7 +39,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback
 
-logger = logging.getLogger(
+logger = logging.getLogger(__name__)
 
 
 class LoadBalanceMethod(Enum):
sglang/srt/managers/{controller/manager_single.py → controller_single.py}
RENAMED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """A controller that manages a group of tensor parallel workers."""
 
 import logging
@@ -7,7 +22,7 @@ from typing import List
 
 import zmq
 
-from sglang.srt.managers.controller.tp_worker import (
+from sglang.srt.managers.tp_worker import (
     ModelTpServer,
     broadcast_recv_input,
     launch_tp_servers,
@@ -16,7 +31,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback
 
-logger = logging.getLogger(
+logger = logging.getLogger(__name__)
 
 
 class ControllerSingle:
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """DetokenizerManager is a process that detokenizes the token ids."""
 
 import asyncio
@@ -10,8 +25,8 @@ import zmq
 import zmq.asyncio
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
+from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry
 
sglang/srt/managers/io_struct.py
CHANGED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """
 The definition of objects transfered between different
 processes (TokenizerManager, DetokenizerManager, Controller).
@@ -7,7 +22,7 @@ import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
 
-from sglang.srt.managers.controller.infer_batch import BaseFinishReason
+from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling_params import SamplingParams
 
 
@@ -64,8 +79,26 @@ class GenerateReqInput:
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
-
-
+            parallel_sample_num_list = []
+            if isinstance(self.sampling_params, dict):
+                parallel_sample_num = self.sampling_params.get("n", 1)
+            elif isinstance(self.sampling_params, list):
+                for sp in self.sampling_params:
+                    parallel_sample_num = sp.get("n", 1)
+                    parallel_sample_num_list.append(parallel_sample_num)
+                parallel_sample_num = max(parallel_sample_num_list)
+                all_equal = all(
+                    element == parallel_sample_num
+                    for element in parallel_sample_num_list
+                )
+                if parallel_sample_num > 1 and (not all_equal):
+                    ## TODO cope with the case that the parallel_sample_num is different for different samples
+                    raise ValueError(
+                        "The parallel_sample_num should be the same for all samples in sample params."
+                    )
+            else:
+                parallel_sample_num = 1
+            self.parallel_sample_num = parallel_sample_num
 
             if parallel_sample_num != 1:
                 # parallel sampling +1 represents the original prefill stage
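The new `__post_init__` branch gives every batched GenerateReqInput a single parallel_sample_num derived from the OpenAI-style "n" field, and rejects batches that mix different n values. A standalone sketch of that rule (the helper name is ours, not part of the package):

# Illustrative restatement of the "n" extraction added above.
def extract_parallel_sample_num(sampling_params):
    if isinstance(sampling_params, dict):
        return sampling_params.get("n", 1)
    if isinstance(sampling_params, list):
        ns = [sp.get("n", 1) for sp in sampling_params]
        n = max(ns)
        if n > 1 and any(v != n for v in ns):
            # mirrors the ValueError raised in __post_init__
            raise ValueError("parallel_sample_num must match across samples")
        return n
    return 1

assert extract_parallel_sample_num({"n": 4}) == 4
assert extract_parallel_sample_num([{"n": 2}, {"n": 2}]) == 2
assert extract_parallel_sample_num(None) == 1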
sglang/srt/managers/{controller/schedule_heuristic.py → policy_scheduler.py}
RENAMED
@@ -1,46 +1,61 @@
-"""
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""Request policy scheduler"""
 
 import random
 from collections import defaultdict
 
 
-class ScheduleHeuristic:
+class PolicyScheduler:
     def __init__(
         self,
-        schedule_heuristic,
+        policy,
         max_running_seqs,
         max_prefill_num_tokens,
         max_total_num_tokens,
         tree_cache,
     ):
-        if tree_cache.disable and schedule_heuristic == "lpm":
+        if tree_cache.disable and policy == "lpm":
             # LMP is meaningless when the tree cache is disabled.
-            schedule_heuristic = "fcfs"
+            policy = "fcfs"
 
-        self.schedule_heuristic = schedule_heuristic
+        self.policy = policy
         self.max_running_seqs = max_running_seqs
         self.max_prefill_num_tokens = max_prefill_num_tokens
         self.max_total_num_tokens = max_total_num_tokens
         self.tree_cache = tree_cache
 
-    def get_priority_queue(self, forward_queue):
-        if self.schedule_heuristic == "lpm":
+    def get_priority_queue(self, waiting_queue):
+        if self.policy == "lpm":
             # longest prefix match
-            forward_queue.sort(key=lambda x: -len(x.prefix_indices))
-            return forward_queue
-        elif self.schedule_heuristic == "fcfs":
+            waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
+            return waiting_queue
+        elif self.policy == "fcfs":
             # first come first serve
-            return forward_queue
-        elif self.schedule_heuristic == "lof":
+            return waiting_queue
+        elif self.policy == "lof":
             # longest output first
-            forward_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
-            return forward_queue
-        elif self.schedule_heuristic == "random":
-            random.shuffle(forward_queue)
-            return forward_queue
-        elif self.schedule_heuristic == "dfs-weight":
+            waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
+            return waiting_queue
+        elif self.policy == "random":
+            random.shuffle(waiting_queue)
+            return waiting_queue
+        elif self.policy == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
-            for req in forward_queue:
+            for req in waiting_queue:
                 last_node_to_reqs[req.last_node].append(req)
 
             node_to_weight = defaultdict(int)
@@ -52,10 +67,10 @@ class ScheduleHeuristic:
             self.get_dfs_priority(
                 self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
             )
-            assert len(q) == len(forward_queue)
+            assert len(q) == len(waiting_queue)
             return q
         else:
-            raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")
+            raise ValueError(f"Unknown schedule_policy: {self.policy}")
 
     def calc_weight(self, cur_node, node_to_weight):
         for child in cur_node.children.values():
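The rename is behavior-preserving: the same five policies ("lpm", "fcfs", "lof", "random", "dfs-weight") survive under new identifiers. A quick sketch of how the class orders a queue, using SimpleNamespace stand-ins for requests and the tree cache (everything except PolicyScheduler itself is a made-up fixture):

from types import SimpleNamespace
from sglang.srt.managers.policy_scheduler import PolicyScheduler  # new 0.2.7 path

def req(prefix_len, max_new_tokens):
    # Minimal stand-in exposing only the attributes the policies touch.
    return SimpleNamespace(
        prefix_indices=[0] * prefix_len,
        sampling_params=SimpleNamespace(max_new_tokens=max_new_tokens),
    )

tree_cache = SimpleNamespace(disable=True)  # cache disabled -> "lpm" falls back to "fcfs"
scheduler = PolicyScheduler("lpm", 32, 8192, 65536, tree_cache)
assert scheduler.policy == "fcfs"

scheduler.policy = "lof"  # longest output first
queue = scheduler.get_priority_queue([req(5, 16), req(9, 128)])
assert queue[0].sampling_params.max_new_tokens == 128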
sglang/srt/managers/{controller/infer_batch.py → schedule_batch.py}
RENAMED
@@ -1,5 +1,21 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """Meta data for requests and batches"""
 
+import logging
 import warnings
 from dataclasses import dataclass
 from enum import IntEnum, auto
@@ -12,8 +28,8 @@ from flashinfer.sampling import top_k_top_p_sampling_from_probs
 from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
-from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
-from sglang.srt.managers.controller.radix_cache import RadixCache
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.radix_cache import RadixCache
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
@@ -25,6 +41,9 @@ global_server_args_dict = {
 }
 
 
+logger = logging.getLogger(__name__)
+
+
 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
     PREFILL = auto()
@@ -364,7 +383,7 @@ class Batch:
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
 
         if out_cache_loc is None:
-
+            logger.error("Prefill out of memory. This should never happen.")
             self.tree_cache.pretty_print()
             exit()
 
@@ -598,7 +617,7 @@ class Batch:
         self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
 
         if self.out_cache_loc is None:
-
+            logger.error("Decode out of memory. This should never happen.")
            self.tree_cache.pretty_print()
            exit()
 
@@ -762,7 +781,7 @@ class InputMetadata:
     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
-
+    flashinfer_use_ragged: bool = False
 
     @classmethod
     def create(
@@ -778,10 +797,10 @@ class InputMetadata:
         return_logprob=False,
         skip_flashinfer_init=False,
     ):
-
+        flashinfer_use_ragged = False
        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
-
+                flashinfer_use_ragged = True
             init_flashinfer_args(
                 forward_mode,
                 model_runner,
@@ -789,7 +808,7 @@ class InputMetadata:
                 seq_lens,
                 prefix_lens,
                 model_runner.flashinfer_decode_wrapper,
-
+                flashinfer_use_ragged,
             )
 
         batch_size = len(req_pool_indices)
@@ -844,7 +863,7 @@ class InputMetadata:
             flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
             flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
             flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
-
+            flashinfer_use_ragged=flashinfer_use_ragged,
         )
 
         if model_runner.server_args.disable_flashinfer:
@@ -865,7 +884,7 @@ def init_flashinfer_args(
     seq_lens,
     prefix_lens,
     flashinfer_decode_wrapper,
-
+    flashinfer_use_ragged=False,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
@@ -874,7 +893,7 @@ def init_flashinfer_args(
     batch_size = len(req_pool_indices)
     total_num_tokens = int(torch.sum(seq_lens))
 
-    if
+    if flashinfer_use_ragged:
         paged_kernel_lens = prefix_lens
     else:
         paged_kernel_lens = seq_lens
@@ -910,7 +929,7 @@ def init_flashinfer_args(
     qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
     qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
 
-    if
+    if flashinfer_use_ragged:
         model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
         model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
             qo_indptr,
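The thread running through these hunks is that an anonymous toggle becomes the explicit flashinfer_use_ragged flag, threaded from InputMetadata.create into init_flashinfer_args. The dispatch rule itself is small enough to restate; this sketch only mirrors the condition and the 4096-token threshold from the hunks above, with `is_decode` standing in for `forward_mode == ForwardMode.DECODE`:

import torch

# Ragged prefill is only selected for large non-decode batches (threshold from the diff).
def use_ragged_prefill(is_decode: bool, seq_lens: torch.Tensor) -> bool:
    return (not is_decode) and int(torch.sum(seq_lens)) > 4096

print(use_ragged_prefill(False, torch.tensor([2048, 3072])))  # True: 5120 > 4096
print(use_ragged_prefill(True, torch.tensor([8192])))         # False: decode never uses ragged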
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """TokenizerManager is a process that tokenizes the text."""
 
 import asyncio
@@ -6,7 +21,7 @@ import dataclasses
 import logging
 import multiprocessing as mp
 import os
-from typing import Dict, List
+from typing import Dict, List, Tuple
 
 import numpy as np
 import transformers
@@ -69,6 +84,7 @@ class TokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
             model_overide_args=model_overide_args,
         )
+
         if server_args.context_length is not None:
             self.context_len = server_args.context_length
         else:
@@ -137,31 +153,33 @@ class TokenizerManager:
         self, obj, request, index=None, is_cache_for_prefill=False
     ):
         if not is_cache_for_prefill:
-            rid = obj.rid if index is None else obj.rid[index]
-            input_text = obj.text if index is None else obj.text[index]
+            not_use_index = not (index is not None)
+            rid = obj.rid if not_use_index else obj.rid[index]
+            input_text = obj.text if not_use_index else obj.text[index]
             input_ids = (
                 self.tokenizer.encode(input_text)
                 if obj.input_ids is None
                 else obj.input_ids
             )
-            if index is not None and obj.input_ids:
+            if not not_use_index and obj.input_ids:
                 input_ids = obj.input_ids[index]
 
             self._validate_input_length(input_ids)
+
             sampling_params = self._get_sampling_params(
-                obj.sampling_params if index is None else obj.sampling_params[index]
+                obj.sampling_params if not_use_index else obj.sampling_params[index]
             )
             pixel_values, image_hash, image_size = await self._get_pixel_values(
-                obj.image_data if index is None else obj.image_data[index]
+                obj.image_data if not_use_index else obj.image_data[index]
             )
             return_logprob = (
-                obj.return_logprob if index is None else obj.return_logprob[index]
+                obj.return_logprob if not_use_index else obj.return_logprob[index]
             )
             logprob_start_len = (
-                obj.logprob_start_len if index is None else obj.logprob_start_len[index]
+                obj.logprob_start_len if not_use_index else obj.logprob_start_len[index]
             )
             top_logprobs_num = (
-                obj.top_logprobs_num if index is None else obj.top_logprobs_num[index]
+                obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
             )
         else:
             if isinstance(obj.text, list):
@@ -209,7 +227,7 @@ class TokenizerManager:
 
     async def _handle_batch_request(self, obj: GenerateReqInput, request):
         batch_size = obj.batch_size
-        parallel_sample_num = obj.
+        parallel_sample_num = obj.parallel_sample_num
 
         if parallel_sample_num != 1:
             # Send prefill requests to cache the common input
@@ -226,7 +244,6 @@ class TokenizerManager:
             obj.input_ids = input_id_result
         elif input_id_result is not None:
             obj.input_ids = input_id_result[0]
-
         # First send out all requests
         for i in range(batch_size):
             for j in range(parallel_sample_num):
@@ -234,7 +251,7 @@ class TokenizerManager:
                     continue
                 index = i * parallel_sample_num + j
                 if parallel_sample_num != 1:
-                    # Here when using parallel sampling we
+                    # Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
                     index += batch_size - 1 - i
                 rid = obj.rid[index]
                 if parallel_sample_num == 1:
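The restored comment is easier to check with concrete numbers. With batch_size = 2 and parallel_sample_num = 3 (illustrative values), request ids 0..1 are consumed by the cache-warming prefill requests sent earlier, and each sampled copy lands at j + i*(parallel_sample_num-1) + batch_size - 1:

# Worked example of the parallel-sampling index formula (values are illustrative).
batch_size, parallel_sample_num = 2, 3

for i in range(batch_size):                  # batch item
    for j in range(1, parallel_sample_num):  # j == 0 assumed handled by the `continue` above
        index = i * parallel_sample_num + j + (batch_size - 1 - i)
        assert index == j + i * (parallel_sample_num - 1) + batch_size - 1
        print(f"item {i}, sample {j} -> rid index {index}")
# item 0 -> indices 2, 3; item 1 -> indices 4, 5 (after the two prefill slots 0 and 1)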
@@ -469,7 +486,9 @@ class TokenizerManager:
         )
         return ret
 
-    def detokenize_logprob_tokens(
+    def detokenize_logprob_tokens(
+        self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
+    ):
         if not decode_to_text:
             return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
 
@@ -481,9 +500,13 @@ class TokenizerManager:
         ]
 
     def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
-        for
-
-
+        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
+        # We should batch all top-k tokens in all positions.
+        for i, token_top_logprobs in enumerate(top_logprobs):
+            if token_top_logprobs:
+                top_logprobs[i] = self.detokenize_logprob_tokens(
+                    token_top_logprobs, decode_to_text
+                )
         return top_logprobs
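For context on the shapes involved: token_logprobs is a list of (logprob, token_id) pairs for one output position, and top_logprobs holds one such list (or None) per position. A standalone version of the per-position loop above, with a stub decoder (all names here are local to the example):

from typing import Callable, List, Optional, Tuple

def detokenize_top_logprobs(
    top_logprobs: List[Optional[List[Tuple[float, int]]]],
    decode: Callable[[int], str],
):
    # Detokenize each position's top-k pairs in place, skipping empty positions.
    for i, token_top_logprobs in enumerate(top_logprobs):
        if token_top_logprobs:
            top_logprobs[i] = [
                (logprob, token_id, decode(token_id))
                for logprob, token_id in token_top_logprobs
            ]
    return top_logprobs

print(detokenize_top_logprobs([[(-0.1, 42)], None], lambda t: f"<tok{t}>"))
# -> [[(-0.1, 42, '<tok42>')], None]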