sglang 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +33 -26
- sglang/api.py +9 -1
- sglang/bench_latency.py +2 -2
- sglang/bench_serving.py +10 -1
- sglang/check_env.py +1 -1
- sglang/lang/backend/litellm.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/interpreter.py +21 -5
- sglang/lang/ir.py +1 -2
- sglang/srt/constrained/__init__.py +15 -0
- sglang/srt/constrained/{base_cache.py → base_tool_cache.py} +17 -2
- sglang/srt/constrained/fsm_cache.py +17 -2
- sglang/srt/constrained/jump_forward.py +17 -2
- sglang/srt/conversation.py +26 -0
- sglang/srt/hf_transformers_utils.py +15 -0
- sglang/srt/layers/context_flashattention_nopad.py +15 -0
- sglang/srt/layers/extend_attention.py +15 -0
- sglang/srt/layers/fused_moe.py +15 -0
- sglang/srt/layers/linear.py +15 -0
- sglang/srt/layers/logits_processor.py +41 -13
- sglang/srt/layers/quantization/__init__.py +15 -0
- sglang/srt/layers/quantization/fp8.py +15 -0
- sglang/srt/layers/radix_attention.py +17 -2
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/{controller/manager_multi.py → controller_multi.py} +17 -2
- sglang/srt/managers/{controller/manager_single.py → controller_single.py} +17 -2
- sglang/srt/managers/detokenizer_manager.py +16 -1
- sglang/srt/managers/io_struct.py +36 -3
- sglang/srt/managers/{controller/schedule_heuristic.py → policy_scheduler.py} +37 -22
- sglang/srt/managers/{controller/infer_batch.py → schedule_batch.py} +60 -21
- sglang/srt/managers/tokenizer_manager.py +39 -16
- sglang/srt/managers/{controller/tp_worker.py → tp_worker.py} +159 -46
- sglang/srt/mem_cache/base_cache.py +43 -0
- sglang/srt/mem_cache/chunk_cache.py +60 -0
- sglang/srt/mem_cache/flush_cache.py +33 -0
- sglang/srt/{memory_pool.py → mem_cache/memory_pool.py} +16 -1
- sglang/srt/{managers/controller → mem_cache}/radix_cache.py +20 -2
- sglang/srt/mm_utils.py +15 -0
- sglang/srt/model_config.py +15 -0
- sglang/srt/{managers/controller → model_executor}/cuda_graph_runner.py +16 -1
- sglang/srt/{managers/controller → model_executor}/model_runner.py +49 -14
- sglang/srt/model_loader/model_loader.py +15 -0
- sglang/srt/model_loader/utils.py +16 -1
- sglang/srt/models/chatglm.py +16 -1
- sglang/srt/models/commandr.py +16 -1
- sglang/srt/models/dbrx.py +16 -1
- sglang/srt/models/deepseek.py +16 -1
- sglang/srt/models/deepseek_v2.py +16 -1
- sglang/srt/models/gemma.py +16 -1
- sglang/srt/models/gemma2.py +16 -1
- sglang/srt/models/gpt_bigcode.py +16 -1
- sglang/srt/models/grok.py +16 -1
- sglang/srt/models/internlm2.py +16 -1
- sglang/srt/models/llama2.py +21 -22
- sglang/srt/models/llama_classification.py +16 -1
- sglang/srt/models/llava.py +17 -2
- sglang/srt/models/llavavid.py +17 -2
- sglang/srt/models/minicpm.py +16 -1
- sglang/srt/models/mistral.py +15 -0
- sglang/srt/models/mixtral.py +16 -1
- sglang/srt/models/mixtral_quant.py +16 -1
- sglang/srt/models/qwen.py +16 -1
- sglang/srt/models/qwen2.py +16 -1
- sglang/srt/models/qwen2_moe.py +16 -1
- sglang/srt/models/stablelm.py +16 -1
- sglang/srt/models/yivl.py +15 -0
- sglang/srt/openai_api/adapter.py +569 -131
- sglang/srt/openai_api/protocol.py +84 -2
- sglang/srt/sampling_params.py +15 -0
- sglang/srt/server.py +92 -23
- sglang/srt/server_args.py +52 -11
- sglang/srt/utils.py +15 -0
- sglang/test/test_programs.py +9 -6
- sglang/utils.py +22 -0
- sglang/version.py +1 -1
- {sglang-0.2.6.dist-info → sglang-0.2.8.dist-info}/METADATA +33 -7
- sglang-0.2.8.dist-info/RECORD +95 -0
- {sglang-0.2.6.dist-info → sglang-0.2.8.dist-info}/WHEEL +1 -1
- sglang/srt/flush_cache.py +0 -18
- sglang-0.2.6.dist-info/RECORD +0 -93
- {sglang-0.2.6.dist-info → sglang-0.2.8.dist-info}/LICENSE +0 -0
- {sglang-0.2.6.dist-info → sglang-0.2.8.dist-info}/top_level.txt +0 -0
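
Taken together, the list above shows three mechanical themes: an Apache-2.0 license header prepended to nearly every module (the recurring +15/+16 counts), the managers/controller package flattened into sglang/srt/managers, and the cache machinery gathered into a new sglang/srt/mem_cache package. For code that imported these internal modules directly, the renames translate as in the sketch below (these are internal paths, not a stable public API; the symbol names are taken from the hunks that follow):

# 0.2.6 (old paths, removed in 0.2.8):
#   from sglang.srt.managers.controller.infer_batch import Batch
#   from sglang.srt.managers.controller.radix_cache import RadixCache
#   from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool

# 0.2.8 (new paths, per the rename list above):
from sglang.srt.managers.schedule_batch import Batch
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool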
sglang/srt/layers/token_attention.py
CHANGED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 # Adapted from
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
@@ -5,7 +20,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.managers.controller.infer_batch import global_server_args_dict
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 
 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
sglang/srt/managers/{controller/manager_multi.py → controller_multi.py}
RENAMED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """
 A controller that manages multiple data parallel workers.
 Each data parallel worker can manage multiple tensor parallel workers.
@@ -12,7 +27,7 @@ from enum import Enum, auto
 import numpy as np
 import zmq
 
-from sglang.srt.managers.controller.manager_single import (
+from sglang.srt.managers.controller_single import (
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.io_struct import (
@@ -24,7 +39,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback
 
-logger = logging.getLogger(
+logger = logging.getLogger(__name__)
 
 
 class LoadBalanceMethod(Enum):
sglang/srt/managers/{controller/manager_single.py → controller_single.py}
RENAMED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """A controller that manages a group of tensor parallel workers."""
 
 import logging
@@ -7,7 +22,7 @@ from typing import List
 
 import zmq
 
-from sglang.srt.managers.controller.tp_worker import (
+from sglang.srt.managers.tp_worker import (
     ModelTpServer,
     broadcast_recv_input,
     launch_tp_servers,
@@ -16,7 +31,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback
 
-logger = logging.getLogger(
+logger = logging.getLogger(__name__)
 
 
 class ControllerSingle:
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """DetokenizerManager is a process that detokenizes the token ids."""
 
 import asyncio
@@ -10,8 +25,8 @@ import zmq
 import zmq.asyncio
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
+from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry
 
sglang/srt/managers/io_struct.py
CHANGED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """
 The definition of objects transfered between different
 processes (TokenizerManager, DetokenizerManager, Controller).
@@ -7,7 +22,7 @@ import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
 
-from sglang.srt.managers.controller.infer_batch import BaseFinishReason
+from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling_params import SamplingParams
 
 
@@ -64,8 +79,26 @@ class GenerateReqInput:
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
-
-
+            parallel_sample_num_list = []
+            if isinstance(self.sampling_params, dict):
+                parallel_sample_num = self.sampling_params.get("n", 1)
+            elif isinstance(self.sampling_params, list):
+                for sp in self.sampling_params:
+                    parallel_sample_num = sp.get("n", 1)
+                    parallel_sample_num_list.append(parallel_sample_num)
+                parallel_sample_num = max(parallel_sample_num_list)
+                all_equal = all(
+                    element == parallel_sample_num
+                    for element in parallel_sample_num_list
+                )
+                if parallel_sample_num > 1 and (not all_equal):
+                    ## TODO cope with the case that the parallel_sample_num is different for different samples
+                    raise ValueError(
+                        "The parallel_sample_num should be the same for all samples in sample params."
+                    )
+            else:
+                parallel_sample_num = 1
+            self.parallel_sample_num = parallel_sample_num
 
         if parallel_sample_num != 1:
             # parallel sampling +1 represents the original prefill stage
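
The new __post_init__ branch above normalizes the OpenAI-style "n" (parallel sample count) across a batch and rejects mixed values. A standalone sketch of the same rule — the helper name is hypothetical, and sampling_params is assumed to be one dict or a list of dicts, as in the hunk:

def resolve_parallel_sample_num(sampling_params):
    # Hypothetical helper mirroring the normalization in the hunk above.
    if isinstance(sampling_params, dict):
        return sampling_params.get("n", 1)
    elif isinstance(sampling_params, list):
        nums = [sp.get("n", 1) for sp in sampling_params]
        n = max(nums)
        if n > 1 and any(x != n for x in nums):
            # Mixed "n" values across one batch are rejected, as above.
            raise ValueError(
                "The parallel_sample_num should be the same for all samples in sample params."
            )
        return n
    return 1

assert resolve_parallel_sample_num({"n": 4}) == 4
assert resolve_parallel_sample_num([{"n": 2}, {"n": 2}]) == 2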
sglang/srt/managers/{controller/schedule_heuristic.py → policy_scheduler.py}
RENAMED
@@ -1,46 +1,61 @@
-"""
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""Request policy scheduler"""
 
 import random
 from collections import defaultdict
 
 
-class ScheduleHeuristic:
+class PolicyScheduler:
     def __init__(
         self,
-
+        policy,
         max_running_seqs,
         max_prefill_num_tokens,
         max_total_num_tokens,
         tree_cache,
     ):
-        if tree_cache.disable and
+        if tree_cache.disable and policy == "lpm":
             # LMP is meaningless when the tree cache is disabled.
-
+            policy = "fcfs"
 
-        self.
+        self.policy = policy
         self.max_running_seqs = max_running_seqs
         self.max_prefill_num_tokens = max_prefill_num_tokens
         self.max_total_num_tokens = max_total_num_tokens
         self.tree_cache = tree_cache
 
-    def get_priority_queue(self,
-        if self.
+    def get_priority_queue(self, waiting_queue):
+        if self.policy == "lpm":
             # longest prefix match
-
-            return
-        elif self.
+            waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
+            return waiting_queue
+        elif self.policy == "fcfs":
             # first come first serve
-            return
-        elif self.
+            return waiting_queue
+        elif self.policy == "lof":
             # longest output first
-
-            return
-        elif self.
-            random.shuffle(
-            return
-        elif self.
+            waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
+            return waiting_queue
+        elif self.policy == "random":
+            random.shuffle(waiting_queue)
+            return waiting_queue
+        elif self.policy == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
-            for req in
+            for req in waiting_queue:
                 last_node_to_reqs[req.last_node].append(req)
 
             node_to_weight = defaultdict(int)
@@ -52,10 +67,10 @@ class ScheduleHeuristic:
             self.get_dfs_priority(
                 self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
             )
-            assert len(q) == len(
+            assert len(q) == len(waiting_queue)
             return q
         else:
-            raise ValueError(f"Unknown
+            raise ValueError(f"Unknown schedule_policy: {self.policy}")
 
     def calc_weight(self, cur_node, node_to_weight):
         for child in cur_node.children.values():
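
Besides the ScheduleHeuristic → PolicyScheduler rename, the hunk spells out the five queue policies: "lpm" (longest prefix match), "fcfs" (first come first serve), "lof" (longest output first), "random", and "dfs-weight". A toy illustration of the two non-trivial sort keys, using stand-in requests whose fields mirror what the scheduler reads:

from types import SimpleNamespace

reqs = [
    SimpleNamespace(rid="a", prefix_indices=[0] * 10,
                    sampling_params=SimpleNamespace(max_new_tokens=32)),
    SimpleNamespace(rid="b", prefix_indices=[0] * 50,
                    sampling_params=SimpleNamespace(max_new_tokens=8)),
]

# "lpm": longest cached prefix first, favoring requests that reuse the radix cache.
reqs.sort(key=lambda x: -len(x.prefix_indices))
assert [r.rid for r in reqs] == ["b", "a"]

# "lof": longest expected output first.
reqs.sort(key=lambda x: -x.sampling_params.max_new_tokens)
assert [r.rid for r in reqs] == ["a", "b"]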
sglang/srt/managers/{controller/infer_batch.py → schedule_batch.py}
RENAMED
@@ -1,5 +1,21 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """Meta data for requests and batches"""
 
+import logging
 import warnings
 from dataclasses import dataclass
 from enum import IntEnum, auto
@@ -12,8 +28,9 @@ from flashinfer.sampling import top_k_top_p_sampling_from_probs
 from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
-from sglang.srt.managers.controller.radix_cache import RadixCache
-from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.chunk_cache import ChunkCache
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.radix_cache import RadixCache
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
@@ -25,6 +42,9 @@ global_server_args_dict = {
 }
 
 
+logger = logging.getLogger(__name__)
+
+
 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
     PREFILL = auto()
@@ -364,7 +384,7 @@ class Batch:
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
 
         if out_cache_loc is None:
-
+            logger.error("Prefill out of memory. This should never happen.")
             self.tree_cache.pretty_print()
             exit()
 
@@ -467,15 +487,33 @@ class Batch:
             req = self.reqs[idx]
             retracted_reqs.append(req)
 
-
-
-
-
-
-
-
-
-
+            if isinstance(self.tree_cache, ChunkCache):
+                # ChunkCache does not have eviction
+                token_indices = self.req_to_token_pool.req_to_token[
+                    req_pool_indices_cpu[idx]
+                ][: seq_lens_cpu[idx]]
+                self.token_to_kv_pool.free(token_indices)
+                self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
+                del self.tree_cache.entries[req.rid]
+            else:
+                # TODO: apply more fine-grained retraction
+                last_uncached_pos = len(req.prefix_indices)
+                token_indices = self.req_to_token_pool.req_to_token[
+                    req_pool_indices_cpu[idx]
+                ][last_uncached_pos : seq_lens_cpu[idx]]
+                self.token_to_kv_pool.free(token_indices)
+                self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
+
+                # release the last node
+                self.tree_cache.dec_lock_ref(req.last_node)
+
+                # NOTE(lsyin): we should use the newly evictable memory instantly.
+                residual_size = (
+                    len(sorted_indices) * global_config.retract_decode_steps
+                    - self.token_to_kv_pool.available_size()
+                )
+                residual_size = max(0, residual_size)
+                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
 
             req.prefix_indices = None
             req.last_node = None
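
The retraction path now splits on cache type: the ChunkCache branch frees a request's token slots outright (chunk caches have no eviction), while the radix-cache branch releases the request's lock on its tree node and immediately evicts enough of the tree for the remaining candidates to run global_config.retract_decode_steps more decode steps. A worked example of the residual_size arithmetic; the numbers are illustrative, not sglang defaults:

num_candidates = 8         # len(sorted_indices) in the hunk
retract_decode_steps = 20  # global_config.retract_decode_steps (value assumed)
available = 100            # self.token_to_kv_pool.available_size()

# Reserve enough KV slots for every candidate to decode
# retract_decode_steps further; evict only the shortfall.
residual_size = max(0, num_candidates * retract_decode_steps - available)
assert residual_size == 60  # 8 * 20 - 100 tokens must be evicted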
@@ -556,6 +594,7 @@ class Batch:
             if req_pool_indices_cpu is None:
                 req_pool_indices_cpu = self.req_pool_indices.tolist()
             self.tree_cache.cache_req(
+                rid=req.rid,
                 token_ids=cur_all_ids,
                 last_uncached_pos=len(req.prefix_indices),
                 req_pool_idx=req_pool_indices_cpu[i],
@@ -598,7 +637,7 @@ class Batch:
         self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
 
         if self.out_cache_loc is None:
-
+            logger.error("Decode out of memory. This should never happen.")
             self.tree_cache.pretty_print()
             exit()
 
@@ -762,7 +801,7 @@ class InputMetadata:
     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
-
+    flashinfer_use_ragged: bool = False
 
     @classmethod
     def create(
@@ -778,10 +817,10 @@ class InputMetadata:
         return_logprob=False,
         skip_flashinfer_init=False,
     ):
-
+        flashinfer_use_ragged = False
        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
             if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
-
+                flashinfer_use_ragged = True
             init_flashinfer_args(
                 forward_mode,
                 model_runner,
@@ -789,7 +828,7 @@ class InputMetadata:
                 seq_lens,
                 prefix_lens,
                 model_runner.flashinfer_decode_wrapper,
-
+                flashinfer_use_ragged,
             )
 
         batch_size = len(req_pool_indices)
@@ -844,7 +883,7 @@ class InputMetadata:
             flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
             flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
             flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
-
+            flashinfer_use_ragged=flashinfer_use_ragged,
         )
 
         if model_runner.server_args.disable_flashinfer:
@@ -865,7 +904,7 @@ def init_flashinfer_args(
     seq_lens,
     prefix_lens,
     flashinfer_decode_wrapper,
-
+    flashinfer_use_ragged=False,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
@@ -874,7 +913,7 @@ def init_flashinfer_args(
     batch_size = len(req_pool_indices)
     total_num_tokens = int(torch.sum(seq_lens))
 
-    if
+    if flashinfer_use_ragged:
         paged_kernel_lens = prefix_lens
     else:
         paged_kernel_lens = seq_lens
@@ -910,7 +949,7 @@ def init_flashinfer_args(
     qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
     qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
 
-    if
+    if flashinfer_use_ragged:
         model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
         model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
             qo_indptr,
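
The flashinfer_use_ragged flag threads from InputMetadata.create into init_flashinfer_args: for non-decode batches totalling more than 4096 tokens, the ragged (non-paged) prefill wrapper handles the newly extended tokens, and only the cached prefix goes through the paged kernel (paged_kernel_lens = prefix_lens instead of seq_lens). A minimal sketch of the gate, with a stub ForwardMode standing in for the real enum in schedule_batch.py:

import torch
from enum import IntEnum, auto

class ForwardMode(IntEnum):  # stub mirroring schedule_batch.ForwardMode
    PREFILL = auto()
    EXTEND = auto()
    DECODE = auto()

def should_use_ragged(forward_mode: ForwardMode, seq_lens: torch.Tensor) -> bool:
    # Ragged prefill applies only to non-decode batches over 4096 total tokens,
    # matching the hard-coded condition in InputMetadata.create above.
    return forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096

assert should_use_ragged(ForwardMode.EXTEND, torch.tensor([3000, 2000]))
assert not should_use_ragged(ForwardMode.DECODE, torch.tensor([9000]))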
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """TokenizerManager is a process that tokenizes the text."""
 
 import asyncio
@@ -6,7 +21,7 @@ import dataclasses
 import logging
 import multiprocessing as mp
 import os
-from typing import Dict, List
+from typing import Dict, List, Tuple
 
 import numpy as np
 import transformers
@@ -69,6 +84,7 @@ class TokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
             model_overide_args=model_overide_args,
         )
+
         if server_args.context_length is not None:
             self.context_len = server_args.context_length
         else:
@@ -137,31 +153,33 @@ class TokenizerManager:
         self, obj, request, index=None, is_cache_for_prefill=False
     ):
         if not is_cache_for_prefill:
-
-
+            not_use_index = not (index is not None)
+            rid = obj.rid if not_use_index else obj.rid[index]
+            input_text = obj.text if not_use_index else obj.text[index]
             input_ids = (
                 self.tokenizer.encode(input_text)
                 if obj.input_ids is None
                 else obj.input_ids
             )
-            if
+            if not not_use_index and obj.input_ids:
                 input_ids = obj.input_ids[index]
 
             self._validate_input_length(input_ids)
+
             sampling_params = self._get_sampling_params(
-                obj.sampling_params if
+                obj.sampling_params if not_use_index else obj.sampling_params[index]
             )
             pixel_values, image_hash, image_size = await self._get_pixel_values(
-                obj.image_data if
+                obj.image_data if not_use_index else obj.image_data[index]
             )
             return_logprob = (
-                obj.return_logprob if
+                obj.return_logprob if not_use_index else obj.return_logprob[index]
             )
             logprob_start_len = (
-                obj.logprob_start_len if
+                obj.logprob_start_len if not_use_index else obj.logprob_start_len[index]
             )
             top_logprobs_num = (
-                obj.top_logprobs_num if
+                obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
             )
         else:
             if isinstance(obj.text, list):
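
The double negation not_use_index = not (index is not None) is simply index is None: each request field is read whole for a single request, or at index for one element of a batch. The repeated conditionals above all reduce to this hypothetical helper:

def pick(value, index):
    # index is None for a single request; otherwise select one batch element.
    return value if index is None else value[index]

# e.g. sampling_params = pick(obj.sampling_params, index)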
@@ -209,7 +227,7 @@ class TokenizerManager:
 
     async def _handle_batch_request(self, obj: GenerateReqInput, request):
         batch_size = obj.batch_size
-        parallel_sample_num = obj.
+        parallel_sample_num = obj.parallel_sample_num
 
         if parallel_sample_num != 1:
             # Send prefill requests to cache the common input
@@ -226,7 +244,6 @@ class TokenizerManager:
                 obj.input_ids = input_id_result
             elif input_id_result is not None:
                 obj.input_ids = input_id_result[0]
-
         # First send out all requests
         for i in range(batch_size):
             for j in range(parallel_sample_num):
@@ -234,7 +251,7 @@ class TokenizerManager:
                     continue
                 index = i * parallel_sample_num + j
                 if parallel_sample_num != 1:
-                    # Here when using parallel sampling we
+                    # Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
                     index += batch_size - 1 - i
                 rid = obj.rid[index]
                 if parallel_sample_num == 1:
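
The index adjustment accounts for the shared prefill requests that occupy the leading slots of the flattened rid list when parallel sampling is on. The in-place form (index = i * parallel_sample_num + j, then index += batch_size - 1 - i) agrees with the formula spelled out in the comment, which a quick check confirms:

batch_size, n = 2, 3  # illustrative sizes

for i in range(batch_size):
    for j in range(n):
        index = i * n + j + (batch_size - 1 - i)
        # Identical to the comment's formula: j + i * (n - 1) + batch_size - 1
        assert index == j + i * (n - 1) + batch_size - 1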
@@ -469,7 +486,9 @@ class TokenizerManager:
         )
         return ret
 
-    def detokenize_logprob_tokens(
+    def detokenize_logprob_tokens(
+        self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
+    ):
         if not decode_to_text:
             return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
 
@@ -481,9 +500,13 @@ class TokenizerManager:
         ]
 
     def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
-        for
-
-
+        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
+        # We should batch all top-k tokens in all positions.
+        for i, token_top_logprobs in enumerate(top_logprobs):
+            if token_top_logprobs:
+                top_logprobs[i] = self.detokenize_logprob_tokens(
+                    token_top_logprobs, decode_to_text
+                )
         return top_logprobs
 
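
As the TODO notes, detokenization is batched only across the top-k tokens at a single position, not across positions. A sketch of what the per-position call amounts to, assuming a Hugging Face-style tokenizer with batch_decode; the decode branch itself is outside the hunk, so its body here is an assumption:

from typing import List, Optional, Tuple

def detokenize_logprob_tokens_sketch(
    tokenizer, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
) -> List[Tuple[float, int, Optional[str]]]:
    if not decode_to_text:
        return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
    # One batch_decode call covers all top-k token ids at this position.
    token_texts = tokenizer.batch_decode([tid for _, tid in token_logprobs])
    return [
        (logprob, token_id, text)
        for (logprob, token_id), text in zip(token_logprobs, token_texts)
    ]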