sglang 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to its public registry. It is provided for informational purposes only.
- sglang/__init__.py +8 -0
- sglang/api.py +10 -2
- sglang/bench_latency.py +151 -40
- sglang/bench_serving.py +46 -22
- sglang/check_env.py +24 -2
- sglang/global_config.py +0 -1
- sglang/lang/backend/base_backend.py +3 -1
- sglang/lang/backend/openai.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +46 -29
- sglang/lang/choices.py +164 -0
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +6 -13
- sglang/lang/ir.py +14 -5
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/layers/activation.py +33 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +6 -1
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +6 -1
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +4 -7
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +174 -380
- sglang/srt/managers/tokenizer_manager.py +197 -112
- sglang/srt/managers/tp_worker.py +299 -364
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +10 -15
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +27 -12
- sglang/srt/model_executor/forward_batch_info.py +319 -0
- sglang/srt/model_executor/model_runner.py +30 -47
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +1 -1
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -2
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/internlm2.py +3 -8
- sglang/srt/models/llama2.py +5 -5
- sglang/srt/models/llama_classification.py +1 -1
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/llava.py +1 -2
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -1
- sglang/srt/models/mixtral.py +1 -1
- sglang/srt/models/mixtral_quant.py +1 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -12
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +189 -39
- sglang/srt/openai_api/protocol.py +43 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -4
- sglang/srt/server.py +93 -21
- sglang/srt/server_args.py +30 -19
- sglang/srt/utils.py +31 -13
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +63 -63
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +4 -2
- sglang/test/test_utils.py +21 -3
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/METADATA +50 -31
- sglang-0.2.12.dist-info/RECORD +112 -0
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang-0.2.10.dist-info/RECORD +0 -100
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py}
CHANGED

@@ -13,6 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+"""
+Memory-efficient attention for prefill.
+It supporst page size = 1.
+"""
+
 # Adapted from
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
 import torch

sglang/srt/layers/radix_attention.py
CHANGED

@@ -20,13 +20,10 @@ from flashinfer.cascade import merge_state
 from torch import nn

 from sglang.global_config import global_config
+from sglang.srt.layers.decode_attention import decode_attention_fwd
 from sglang.srt.layers.extend_attention import extend_attention_fwd
-from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.model_executor.model_runner import (
-    ForwardMode,
-    InputMetadata,
-    global_server_args_dict,
-)
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.model_runner import global_server_args_dict


 class RadixAttention(nn.Module):
@@ -98,7 +95,7 @@ class RadixAttention(nn.Module):
         o = torch.empty_like(q)
         self.store_kv_cache(k, v, input_metadata)

-        token_attention_fwd(
+        decode_attention_fwd(
             q.view(-1, self.tp_q_head_num, self.qk_head_dim),
             input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
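
Per the hunks above, the decode path now calls decode_attention_fwd (token_attention.py is renamed to decode_attention.py in the file list), and ForwardMode / InputMetadata are imported from the new forward_batch_info module instead of model_runner. Downstream code importing these names would migrate roughly as follows (a minimal sketch; only the import lines change, old paths per 0.2.10):

# sglang 0.2.10 (old import location, commented out for comparison):
# from sglang.srt.model_executor.model_runner import ForwardMode, InputMetadata

# sglang 0.2.12 (new locations, as in the hunk above):
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.model_executor.model_runner import global_server_args_dict
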
sglang/srt/managers/detokenizer_manager.py
CHANGED

@@ -25,10 +25,14 @@ import zmq
 import zmq.asyncio

 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
+from sglang.srt.managers.io_struct import (
+    BatchEmbeddingOut,
+    BatchStrOut,
+    BatchTokenIDOut,
+)
 from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.utils import find_printable_text, get_exception_traceback
+from sglang.utils import find_printable_text, get_exception_traceback

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -55,20 +59,40 @@ class DetokenizerManager:
         self.send_to_tokenizer = context.socket(zmq.PUSH)
         self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

-        self.tokenizer = get_tokenizer(
-            server_args.tokenizer_path,
-            tokenizer_mode=server_args.tokenizer_mode,
-            trust_remote_code=server_args.trust_remote_code,
-        )
+        if server_args.skip_tokenizer_init:
+            self.tokenizer = None
+        else:
+            self.tokenizer = get_tokenizer(
+                server_args.tokenizer_path,
+                tokenizer_mode=server_args.tokenizer_mode,
+                trust_remote_code=server_args.trust_remote_code,
+            )

         self.decode_status = {}

     async def handle_loop(self):
         while True:
             recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
+
+            if isinstance(recv_obj, BatchEmbeddingOut):
+                self.send_to_tokenizer.send_pyobj(
+                    BatchEmbeddingOut(
+                        rids=recv_obj.rids,
+                        embeddings=recv_obj.embeddings,
+                        meta_info=recv_obj.meta_info,
+                        finished_reason=recv_obj.finished_reason,
+                    )
+                )
+                continue
+
             assert isinstance(recv_obj, BatchTokenIDOut)
             bs = len(recv_obj.rids)

+            if self.tokenizer is None:
+                # Send BatchTokenIDOut if no tokenizer init'ed.
+                self.send_to_tokenizer.send_pyobj(recv_obj)
+                continue
+
             # Initialize decode status
             read_ids, surr_ids = [], []
             for i in range(bs):
@@ -140,8 +164,6 @@ def start_detokenizer_process(
     port_args: PortArgs,
     pipe_writer,
 ):
-    graceful_registry(inspect.currentframe().f_code.co_name)
-
     try:
         manager = DetokenizerManager(server_args, port_args)
     except Exception:
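
With the new skip_tokenizer_init branch above, the detokenizer keeps self.tokenizer as None and forwards BatchTokenIDOut objects unchanged, so the client is expected to decode the returned token IDs itself. A minimal client-side sketch of that step, assuming a Hugging Face tokenizer and illustrative token IDs (neither the model name nor the IDs are part of this diff):

from transformers import AutoTokenizer

# Assumption: the server was launched with skip_tokenizer_init, so the response
# carries raw output token IDs rather than detokenized text.
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative tokenizer

output_ids = [15496, 11, 995, 0]  # illustrative IDs as returned by the server
print(tokenizer.decode(output_ids, skip_special_tokens=True))
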
sglang/srt/managers/io_struct.py
CHANGED
@@ -22,6 +22,8 @@ import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union

+import torch
+
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling_params import SamplingParams

@@ -166,6 +168,59 @@ class TokenizedGenerateReqInput:
     stream: bool


+@dataclass
+class EmbeddingReqInput:
+    # The input prompt. It can be a single prompt or a batch of prompts.
+    text: Optional[Union[List[str], str]] = None
+    # The token ids for text; one can either specify text or input_ids.
+    input_ids: Optional[Union[List[List[int]], List[int]]] = None
+    # The request id.
+    rid: Optional[Union[List[str], str]] = None
+    # Dummy sampling params for compatibility
+    sampling_params: Union[List[Dict], Dict] = None
+
+    def post_init(self):
+        if (self.text is None and self.input_ids is None) or (
+            self.text is not None and self.input_ids is not None
+        ):
+            raise ValueError("Either text or input_ids should be provided.")
+
+        if self.text is not None:
+            is_single = isinstance(self.text, str)
+        else:
+            is_single = isinstance(self.input_ids[0], int)
+        self.is_single = is_single
+
+        if is_single:
+            if self.rid is None:
+                self.rid = uuid.uuid4().hex
+            if self.sampling_params is None:
+                self.sampling_params = {}
+            self.sampling_params["max_new_tokens"] = 1
+        else:
+            # support select operation
+            self.batch_size = (
+                len(self.text) if self.text is not None else len(self.input_ids)
+            )
+            if self.rid is None:
+                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
+            else:
+                if not isinstance(self.rid, list):
+                    raise ValueError("The rid should be a list.")
+            if self.sampling_params is None:
+                self.sampling_params = [{}] * self.batch_size
+            for i in range(self.batch_size):
+                self.sampling_params[i]["max_new_tokens"] = 1
+
+
+@dataclass
+class TokenizedEmbeddingReqInput:
+    rid: str
+    input_text: str
+    input_ids: List[int]
+    sampling_params: SamplingParams
+
+
 @dataclass
 class BatchTokenIDOut:
     rids: List[str]
@@ -187,6 +242,14 @@ class BatchStrOut:
     finished_reason: List[BaseFinishReason]


+@dataclass
+class BatchEmbeddingOut:
+    rids: List[str]
+    embeddings: List[List[float]]
+    meta_info: List[Dict]
+    finished_reason: List[BaseFinishReason]
+
+
 @dataclass
 class FlushCacheReq:
     pass
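
The EmbeddingReqInput.post_init logic above normalizes single-prompt and batched requests, filling in rid and the dummy max_new_tokens=1 sampling params. A minimal sketch of that behavior, constructing the dataclass directly and calling post_init by hand purely for illustration (inside the server this is driven by the request path, not by user code):

from sglang.srt.managers.io_struct import EmbeddingReqInput

# Single prompt: rid becomes one hex string and sampling_params a single dict.
single = EmbeddingReqInput(text="hello world")
single.post_init()
print(single.is_single)        # True
print(single.sampling_params)  # {'max_new_tokens': 1}

# Batch of prompts: one rid per prompt, one sampling_params dict per prompt.
batch = EmbeddingReqInput(text=["hello", "world"])
batch.post_init()
print(batch.batch_size)  # 2
print(len(batch.rid))    # 2
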
sglang/srt/managers/policy_scheduler.py
CHANGED

@@ -15,44 +15,54 @@ limitations under the License.

 """Request policy scheduler"""

+import os
 import random
 from collections import defaultdict
+from contextlib import contextmanager
+from typing import Dict, List, Optional
+
+from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.radix_cache import TreeNode
+
+# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
+# This can prevent the server from being too conservative.
+# Note that this only clips the estimation in the scheduler but does not change the stop
+# condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
+CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))


 class PolicyScheduler:
-    def __init__(
-        self,
-        policy,
-        max_running_seqs,
-        max_prefill_num_tokens,
-        max_total_num_tokens,
-        tree_cache,
-    ):
-        if tree_cache.disable and policy == "lpm":
-            # LMP is meaningless when the tree cache is disabled.
+    def __init__(self, policy: str, tree_cache: BasePrefixCache):
+        if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
+            # LPM and DFS-weight is meaningless when the tree cache is disabled.
             policy = "fcfs"

         self.policy = policy
-        self.max_running_seqs = max_running_seqs
-        self.max_prefill_num_tokens = max_prefill_num_tokens
-        self.max_total_num_tokens = max_total_num_tokens
         self.tree_cache = tree_cache

-    def
+    def calc_priority(self, waiting_queue: List[Req]):
+        # Compute matched prefix length
+        prefix_computed = False
+        if self.policy in ["lpm", "dfs-weight"]:
+            for r in waiting_queue:
+                # NOTE: the prefix_indices must always be aligned with last_node
+                r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
+                    rid=r.rid, key=r.adjust_max_prefix_ids()
+                )
+            prefix_computed = True
+
         if self.policy == "lpm":
-            #
+            # Longest Prefix Match
             waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
-            return waiting_queue
         elif self.policy == "fcfs":
             # first come first serve
-            return waiting_queue
+            pass
         elif self.policy == "lof":
             # longest output first
             waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
-            return waiting_queue
         elif self.policy == "random":
             random.shuffle(waiting_queue)
-            return waiting_queue
         elif self.policy == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
             for req in waiting_queue:
@@ -63,23 +73,161 @@ class PolicyScheduler:
             node_to_weight[node] = len(last_node_to_reqs[node])
             self.calc_weight(self.tree_cache.root_node, node_to_weight)

-            q = []
+            waiting_queue.clear()
             self.get_dfs_priority(
-                self.tree_cache.root_node,
+                self.tree_cache.root_node,
+                node_to_weight,
+                last_node_to_reqs,
+                waiting_queue,
             )
-            assert len(q) == len(waiting_queue)
-            return q
         else:
             raise ValueError(f"Unknown schedule_policy: {self.policy}")

-
+        return prefix_computed
+
-    def calc_weight(self, cur_node, node_to_weight):
+    def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict):
         for child in cur_node.children.values():
             self.calc_weight(child, node_to_weight)
             node_to_weight[cur_node] += node_to_weight[child]

-    def get_dfs_priority(self, cur_node, node_to_priority, last_node_to_reqs, q):
+    def get_dfs_priority(
+        self,
+        cur_node: TreeNode,
+        node_to_priority: Dict,
+        last_node_to_reqs: Dict,
+        q: List,
+    ):
         childs = [child for child in cur_node.children.values()]
         childs.sort(key=lambda x: -node_to_priority[x])
         for child in childs:
             self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
         q.extend(last_node_to_reqs[cur_node])
+
+
+class PrefillAdder:
+    def __init__(
+        self,
+        tree_cache: BasePrefixCache,
+        rem_total_tokens: int,
+        rem_input_tokens: int,
+        rem_chunk_tokens: Optional[int],
+    ):
+        self.tree_cache = tree_cache
+        self.rem_total_tokens = rem_total_tokens
+        self.rem_input_tokens = rem_input_tokens
+        self.rem_chunk_tokens = rem_chunk_tokens
+
+        self.can_run_list = []
+        self.new_inflight_req = None
+        self.log_hit_tokens = 0
+        self.log_input_tokens = 0
+
+    def no_remaining_tokens(self):
+        return (
+            self.rem_total_tokens <= 0
+            or self.rem_input_tokens <= 0
+            or (
+                self.rem_chunk_tokens <= 0
+                if self.rem_chunk_tokens is not None
+                else False
+            )
+        )
+
+    def remove_running_tokens(
+        self, running_batch: ScheduleBatch, new_token_ratio: float
+    ):
+        self.rem_total_tokens -= sum(
+            [
+                min(
+                    (r.sampling_params.max_new_tokens - len(r.output_ids)),
+                    CLIP_MAX_NEW_TOKENS,
+                )
+                * new_token_ratio
+                for r in running_batch.reqs
+            ]
+        )
+
+    def _prefill_one_req(
+        self, prefix_len: int, extend_input_len: int, max_new_tokens: int
+    ):
+        self.rem_total_tokens -= extend_input_len + max_new_tokens
+        self.rem_input_tokens -= extend_input_len
+        if self.rem_chunk_tokens is not None:
+            self.rem_chunk_tokens -= extend_input_len
+
+        self.log_hit_tokens += prefix_len
+        self.log_input_tokens += extend_input_len
+
+    def add_inflight_req(self, req: Req):
+        truncated = req.extend_input_len > self.rem_chunk_tokens
+        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
+        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
+        self.can_run_list.append(req)
+
+        self._prefill_one_req(
+            len(req.prefix_indices),
+            req.extend_input_len,
+            (
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
+                if not truncated
+                else 0
+            ),
+        )
+
+        # Return if chunked prefill not finished
+        return req if truncated else None
+
+    @contextmanager
+    def _lock_node(self, last_node: TreeNode):
+        try:
+            delta = self.tree_cache.inc_lock_ref(last_node)
+            self.rem_total_tokens += delta
+            yield None
+        finally:
+            delta = self.tree_cache.dec_lock_ref(last_node)
+            self.rem_total_tokens += delta
+
+    def add_one_req(self, req: Req):
+        total_tokens = req.extend_input_len + min(
+            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
+        )
+        input_tokens = req.extend_input_len
+        prefix_len = len(req.prefix_indices)
+
+        if total_tokens >= self.rem_total_tokens:
+            return False
+
+        if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
+            return False
+
+        with self._lock_node(req.last_node):
+            if total_tokens > self.rem_total_tokens:
+                return False
+
+            if (
+                self.rem_chunk_tokens is None
+                or input_tokens <= self.rem_chunk_tokens
+                or (req.return_logprob and req.normalized_prompt_logprob is None)
+            ):
+                # Non-chunked prefill
+                self.can_run_list.append(req)
+                self.tree_cache.inc_lock_ref(req.last_node)
+                self._prefill_one_req(
+                    prefix_len,
+                    input_tokens,
+                    min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
+                )
+            else:
+                # Chunked prefill
+                trunc_len = self.rem_chunk_tokens
+                if trunc_len == 0:
+                    return False
+
+                req.extend_input_len = trunc_len
+                req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
+                self.can_run_list.append(req)
+                self.new_inflight_req = req
+                self.tree_cache.inc_lock_ref(req.last_node)
+                self._prefill_one_req(prefix_len, trunc_len, 0)
+
+        return True
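
Per the hunk above, the scheduler's clip on the estimated max_new_tokens budget is now configurable through the SGLANG_CLIP_MAX_NEW_TOKENS environment variable (default 4096); as the module comment notes, this only affects the scheduler's estimate, not the generation stop condition. Because the value is read once at module import, the variable has to be set before the scheduler module is imported (typically before launching the server). A minimal sketch with an illustrative override value:

import os

# Illustrative override; must be set before
# sglang.srt.managers.policy_scheduler is imported.
os.environ["SGLANG_CLIP_MAX_NEW_TOKENS"] = "1024"

from sglang.srt.managers.policy_scheduler import CLIP_MAX_NEW_TOKENS

print(CLIP_MAX_NEW_TOKENS)  # 1024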