sglang 0.3.1.post3__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +23 -1
- sglang/bench_latency.py +48 -33
- sglang/bench_server_latency.py +0 -6
- sglang/bench_serving.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +14 -1
- sglang/lang/interpreter.py +16 -6
- sglang/lang/ir.py +20 -4
- sglang/srt/configs/model_config.py +11 -9
- sglang/srt/constrained/fsm_cache.py +9 -1
- sglang/srt/constrained/jump_forward.py +15 -2
- sglang/srt/hf_transformers_utils.py +1 -0
- sglang/srt/layers/activation.py +4 -4
- sglang/srt/layers/attention/__init__.py +49 -0
- sglang/srt/layers/attention/flashinfer_backend.py +277 -0
- sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
- sglang/srt/layers/attention/triton_backend.py +161 -0
- sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
- sglang/srt/layers/fused_moe/patch.py +117 -0
- sglang/srt/layers/layernorm.py +4 -4
- sglang/srt/layers/logits_processor.py +19 -15
- sglang/srt/layers/pooler.py +3 -3
- sglang/srt/layers/quantization/__init__.py +0 -2
- sglang/srt/layers/radix_attention.py +6 -4
- sglang/srt/layers/sampler.py +6 -4
- sglang/srt/layers/torchao_utils.py +18 -0
- sglang/srt/lora/lora.py +20 -21
- sglang/srt/lora/lora_manager.py +97 -25
- sglang/srt/managers/detokenizer_manager.py +31 -18
- sglang/srt/managers/image_processor.py +187 -0
- sglang/srt/managers/io_struct.py +99 -75
- sglang/srt/managers/schedule_batch.py +187 -68
- sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
- sglang/srt/managers/scheduler.py +1021 -0
- sglang/srt/managers/tokenizer_manager.py +120 -247
- sglang/srt/managers/tp_worker.py +28 -925
- sglang/srt/mem_cache/memory_pool.py +34 -52
- sglang/srt/mem_cache/radix_cache.py +5 -5
- sglang/srt/model_executor/cuda_graph_runner.py +25 -25
- sglang/srt/model_executor/forward_batch_info.py +94 -97
- sglang/srt/model_executor/model_runner.py +76 -78
- sglang/srt/models/baichuan.py +10 -10
- sglang/srt/models/chatglm.py +12 -12
- sglang/srt/models/commandr.py +10 -10
- sglang/srt/models/dbrx.py +12 -12
- sglang/srt/models/deepseek.py +10 -10
- sglang/srt/models/deepseek_v2.py +14 -15
- sglang/srt/models/exaone.py +10 -10
- sglang/srt/models/gemma.py +10 -10
- sglang/srt/models/gemma2.py +11 -11
- sglang/srt/models/gpt_bigcode.py +10 -10
- sglang/srt/models/grok.py +10 -10
- sglang/srt/models/internlm2.py +10 -10
- sglang/srt/models/llama.py +22 -10
- sglang/srt/models/llama_classification.py +5 -5
- sglang/srt/models/llama_embedding.py +4 -4
- sglang/srt/models/llama_reward.py +142 -0
- sglang/srt/models/llava.py +39 -33
- sglang/srt/models/llavavid.py +31 -28
- sglang/srt/models/minicpm.py +10 -10
- sglang/srt/models/minicpm3.py +14 -15
- sglang/srt/models/mixtral.py +10 -10
- sglang/srt/models/mixtral_quant.py +10 -10
- sglang/srt/models/olmoe.py +10 -10
- sglang/srt/models/qwen.py +10 -10
- sglang/srt/models/qwen2.py +11 -11
- sglang/srt/models/qwen2_moe.py +10 -10
- sglang/srt/models/stablelm.py +10 -10
- sglang/srt/models/torch_native_llama.py +506 -0
- sglang/srt/models/xverse.py +10 -10
- sglang/srt/models/xverse_moe.py +10 -10
- sglang/srt/openai_api/adapter.py +7 -0
- sglang/srt/sampling/sampling_batch_info.py +36 -27
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +170 -119
- sglang/srt/server_args.py +54 -27
- sglang/srt/utils.py +101 -128
- sglang/test/runners.py +76 -33
- sglang/test/test_programs.py +38 -5
- sglang/test/test_utils.py +53 -9
- sglang/version.py +1 -1
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/METADATA +42 -23
- sglang-0.3.3.dist-info/RECORD +139 -0
- sglang/srt/layers/attention_backend.py +0 -482
- sglang/srt/managers/controller_multi.py +0 -207
- sglang/srt/managers/controller_single.py +0 -164
- sglang-0.3.1.post3.dist-info/RECORD +0 -134
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
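
The largest structural change in this release is the new standalone scheduler: the old controller_multi.py / controller_single.py processes are removed, most of tp_worker.py moves into the new sglang/srt/managers/scheduler.py, and that file is the one expanded in detail below. As a minimal, hypothetical sketch (not part of the diff), the snippet under this list shows how one might spawn that scheduler per tensor-parallel rank; it relies only on the run_scheduler_process() entry point and its "ready" pipe handshake defined in the new file, while the launch_tp_schedulers helper and the gpu_id == tp_rank mapping are illustrative assumptions.

import multiprocessing as mp

from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.server_args import PortArgs, ServerArgs


def launch_tp_schedulers(server_args: ServerArgs, port_args: PortArgs):
    """Spawn one scheduler process per TP rank and block until each reports ready."""
    ctx = mp.get_context("spawn")  # spawn is the safe start method with CUDA
    procs, readers = [], []
    for tp_rank in range(server_args.tp_size):
        reader, writer = ctx.Pipe(duplex=False)
        proc = ctx.Process(
            target=run_scheduler_process,
            # gpu_id assumed equal to tp_rank for this single-node sketch
            args=(server_args, port_args, tp_rank, tp_rank, writer),
        )
        proc.start()
        procs.append(proc)
        readers.append(reader)

    # Each scheduler sends "ready" on its pipe once initialization succeeds.
    for reader in readers:
        assert reader.recv() == "ready"
    return procs
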
sglang/srt/managers/scheduler.py (new file)
@@ -0,0 +1,1021 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""A scheduler that manages a tensor parallel GPU worker."""

import json
import logging
import multiprocessing
import os
import time
import warnings
from typing import List, Optional, Union

import torch
import zmq

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchTokenIDOut,
    FlushCacheReq,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    TokenizedRewardReqInput,
    UpdateWeightReqInput,
    UpdateWeightReqOutput,
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
    BaseFinishReason,
    ImageInputs,
    Req,
    ScheduleBatch,
)
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    broadcast_pyobj,
    configure_logger,
    is_generation_model,
    is_multimodal_model,
    kill_parent_process,
    pytorch_profile,
    set_random_seed,
    suppress_other_loggers,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

# Crash on warning if we are running CI tests
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"


class Scheduler:
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
    ):
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_policy = server_args.schedule_policy
        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch

        # Init inter-process communication
        context = zmq.Context(2)

        if self.tp_rank == 0:
            self.recv_from_tokenizer = context.socket(zmq.PULL)
            self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")

            self.send_to_detokenizer = context.socket(zmq.PUSH)
            self.send_to_detokenizer.connect(f"ipc://{port_args.detokenizer_ipc_name}")
        else:
            self.recv_from_tokenizer = self.send_to_detokenizer = None

        # Init tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            server_args.trust_remote_code,
            context_length=server_args.context_length,
            model_override_args=json.loads(server_args.json_model_override_args),
        )

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if is_multimodal_model(self.model_config.hf_config.architectures):
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
        self.is_generation = is_generation_model(
            self.model_config.hf_config.architectures, self.server_args.is_embedding
        )

        # Launch a tensor parallel worker
        self.tp_worker = TpModelWorker(
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            server_args=server_args,
            nccl_port=port_args.nccl_ports[0],
        )
        self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_input_len,
            self.random_seed,
        ) = self.tp_worker.get_token_and_memory_info()
        set_random_seed(self.random_seed)
        self.pad_input_ids_func = getattr(
            self.tp_worker.model_runner.model, "pad_input_ids", None
        )

        # Print debug info
        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # Init cache
        self.req_to_token_pool = self.tp_worker.model_runner.req_to_token_pool
        self.token_to_kv_pool = self.tp_worker.model_runner.token_to_kv_pool

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            self.tree_cache = ChunkCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
                disable=server_args.disable_radix_cache,
            )
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

        # Init running status
        self.waiting_queue: List[Req] = []
        self.running_batch: ScheduleBatch = None
        self.out_pyobjs = []
        self.decode_forward_ct = 0
        self.stream_interval = server_args.stream_interval
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        self.current_inflight_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the FSM cache for constrained generation
        if not server_args.skip_tokenizer_init:
            self.regex_fsm_cache = FSMCache(
                server_args.tokenizer_path,
                {
                    "tokenizer_mode": server_args.tokenizer_mode,
                    "trust_remote_code": server_args.trust_remote_code,
                },
                skip_tokenizer_init=server_args.skip_tokenizer_init,
                constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
            )
        self.jump_forward_cache = JumpForwardCache()

        # Init new token estimation
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        self.min_new_token_ratio = min(
            global_config.base_min_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.new_token_ratio = self.min_new_token_ratio
        self.new_token_ratio_decay = global_config.new_token_ratio_decay
        self.batch_is_full = False

    @torch.inference_mode()
    def event_loop(self):
        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            self.run_step()

            self.send_results()

    def recv_requests(self):
        if self.tp_rank == 0:
            recv_reqs = []

            while True:
                try:
                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    break
                recv_reqs.append(recv_req)
        else:
            recv_reqs = None

        if self.tp_size != 1:
            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
        return recv_reqs

    def process_input_requests(self, recv_reqs: List):
        for recv_req in recv_reqs:
            if isinstance(recv_req, TokenizedGenerateReqInput):
                self.handle_generate_request(recv_req)
            elif isinstance(
                recv_req, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput)
            ):
                self.handle_embedding_request(recv_req)
            elif isinstance(recv_req, FlushCacheReq):
                self.flush_cache()
            elif isinstance(recv_req, AbortReq):
                self.abort_request(recv_req)
            elif isinstance(recv_req, UpdateWeightReqInput):
                success, message = self.update_weights(recv_req)
                self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
            else:
                raise ValueError(f"Invalid request: {recv_req}")

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
            lora_path=recv_req.lora_path,
        )
        req.tokenizer = self.tokenizer

        # Image inputs
        if recv_req.image_inputs is not None:
            req.image_inputs = ImageInputs.from_dict(
                recv_req.image_inputs, self.model_config.vocab_size
            )
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids_unpadded, req.image_inputs
            )

        req.return_logprob = recv_req.return_logprob
        req.top_logprobs_num = recv_req.top_logprobs_num
        req.stream = recv_req.stream
        req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len == -1:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(recv_req.input_ids) - 1

        # Init regex FSM
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
        ):
            if req.sampling_params.json_schema is not None:
                req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
                    ("json", req.sampling_params.json_schema)
                )
            elif req.sampling_params.regex is not None:
                req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
                    ("regex", req.sampling_params.regex)
                )
            if not self.disable_regex_jump_forward:
                req.jump_forward_map = self.jump_forward_cache.query(
                    computed_regex_string
                )

        # Truncate prompts that are too long
        if len(req.origin_input_ids) >= self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_input_len - 1 - len(req.origin_input_ids),
        )

        self.waiting_queue.append(req)

    def handle_embedding_request(
        self,
        recv_req: Union[TokenizedEmbeddingReqInput, TokenizedRewardReqInput],
    ):
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Truncate prompts that are too long
        if len(req.origin_input_ids) >= self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

        self.waiting_queue.append(req)

    def send_results(self):
        if self.tp_rank == 0:
            for obj in self.out_pyobjs:
                self.send_to_detokenizer.send_pyobj(obj)
            self.out_pyobjs = []

    def print_decode_stats(self):
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()
        logger.info(
            f"Decode batch. "
            f"#running-req: {len(self.running_batch.reqs)}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"gen throughput (token/s): {throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )

    def check_memory(self):
        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        if available_size != self.max_total_num_tokens:
            warnings.warn(
                "Warning: "
                f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
                "KV cache pool leak detected!"
            )
            exit(1) if crash_on_warning else None

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            warnings.warn(
                "Warning: "
                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
                f"total slots={self.req_to_token_pool.size}\n"
                "Memory pool leak detected!"
            )
            exit(1) if crash_on_warning else None

    def run_step(self):
        new_batch = self.get_new_batch_prefill()
        if new_batch is not None:
            # Run a new prefill batch
            # replace run_batch with the uncommented line to use pytorch profiler
            # result = pytorch_profile(
            #     "profile_prefill_step", self.run_batch, new_batch, data_size=len(new_batch.reqs)
            # )
            result = self.run_batch(new_batch)
            self.process_batch_result(new_batch, result)
        else:
            if self.running_batch is not None:
                # Run a few decode batches continuously for reducing overhead
                for _ in range(global_config.num_continue_decode_steps):
                    batch = self.get_new_batch_decode()

                    if batch:
                        # replace run_batch with the uncommented line to use pytorch profiler
                        # result = pytorch_profile(
                        #     "profile_decode_step",
                        #     self.run_batch,
                        #     batch,
                        #     data_size=len(batch.reqs),
                        # )
                        result = self.run_batch(batch)
                        self.process_batch_result(batch, result)

                    if self.running_batch is None:
                        break

                    if self.out_pyobjs and self.running_batch.has_stream:
                        break
            else:
                self.check_memory()
                self.new_token_ratio = global_config.init_new_token_ratio

    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        # Handle the cases where prefill is not allowed
        if (
            self.batch_is_full or len(self.waiting_queue) == 0
        ) and self.current_inflight_req is None:
            return None

        running_bs = (
            len(self.running_batch.reqs) if self.running_batch is not None else 0
        )
        if running_bs >= self.max_running_requests:
            self.batch_is_full = True
            return None

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        num_mixed_running = running_bs if self.is_mixed_chunk else 0
        adder = PrefillAdder(
            self.tree_cache,
            self.running_batch,
            self.new_token_ratio,
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            num_mixed_running,
        )

        has_inflight = self.current_inflight_req is not None
        if self.current_inflight_req is not None:
            self.current_inflight_req.init_next_round_input(
                None if prefix_computed else self.tree_cache
            )
            self.current_inflight_req = adder.add_inflight_req(
                self.current_inflight_req
            )

        if self.lora_paths is not None:
            lora_set = (
                set([req.lora_path for req in self.running_batch.reqs])
                if self.running_batch is not None
                else set([])
            )

        for req in self.waiting_queue:
            if (
                self.lora_paths is not None
                and len(
                    lora_set
                    | set([req.lora_path for req in adder.can_run_list])
                    | set([req.lora_path])
                )
                > self.max_loras_per_batch
            ):
                self.batch_is_full = True
                break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
                self.batch_is_full = True
                break

            req.init_next_round_input(None if prefix_computed else self.tree_cache)
            res = adder.add_one_req(req)
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    self.batch_is_full = True
                break

        can_run_list = adder.can_run_list

        if adder.new_inflight_req is not None:
            assert self.current_inflight_req is None
            self.current_inflight_req = adder.new_inflight_req

        if len(can_run_list) == 0:
            return None

        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]

        # Print stats
        if self.tp_rank == 0:
            if isinstance(self.tree_cache, RadixCache):
                self.tree_cache_metrics["total"] += (
                    adder.log_input_tokens + adder.log_hit_tokens
                ) / 10**9
                self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
                tree_cache_hit_rate = (
                    self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
                )
            else:
                tree_cache_hit_rate = 0.0

            num_used = self.max_total_num_tokens - (
                self.token_to_kv_pool.available_size()
                + self.tree_cache.evictable_size()
            )

            if num_mixed_running > 0:
                logger.info(
                    f"Prefill batch"
                    f"(mixed #running-req: {num_mixed_running}). "
                    f"#new-seq: {len(can_run_list)}, "
                    f"#new-token: {adder.log_input_tokens}, "
                    f"#cached-token: {adder.log_hit_tokens}, "
                    f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                    f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                    f"#queue-req: {len(self.waiting_queue) + has_inflight}"
                )
            else:
                logger.info(
                    f"Prefill batch. "
                    f"#new-seq: {len(can_run_list)}, "
                    f"#new-token: {adder.log_input_tokens}, "
                    f"#cached-token: {adder.log_hit_tokens}, "
                    f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                    f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                    f"#running-req: {running_bs}, "
                    f"#queue-req: {len(self.waiting_queue) + has_inflight}"
                )

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
        )
        new_batch.prepare_for_extend(self.model_config.vocab_size)

        # Mixed-style chunked prefill
        decoding_reqs = []
        if self.is_mixed_chunk and self.running_batch is not None:
            self.running_batch.prepare_for_decode()
            new_batch.mix_with_running(self.running_batch)
            decoding_reqs = self.running_batch.reqs
            self.running_batch = None
        new_batch.decoding_reqs = decoding_reqs

        return new_batch

    def get_new_batch_decode(self) -> Optional[ScheduleBatch]:
        batch = self.running_batch

        # Check if decode out of memory
        if not batch.check_decode_mem():
            old_ratio = self.new_token_ratio

            retracted_reqs, new_token_ratio = batch.retract_decode()
            self.new_token_ratio = new_token_ratio

            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.waiting_queue.extend(retracted_reqs)
        else:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        # Check for jump-forward
        if not self.disable_regex_jump_forward:
            jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
            self.waiting_queue.extend(jump_forward_reqs)
            if batch.is_empty():
                return None

        # Update batch tensors
        batch.prepare_for_decode()
        return batch

    def run_batch(self, batch: ScheduleBatch):
        if self.is_generation:
            if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                model_worker_batch = batch.get_model_worker_batch()
                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                    model_worker_batch
                )
            else:
                logits_output = None
                if self.tokenizer is not None:
                    next_token_ids = torch.full(
                        (batch.batch_size(),), self.tokenizer.eos_token_id
                    )
                else:
                    next_token_ids = torch.full((batch.batch_size(),), 0)
            return logits_output, next_token_ids
        else:  # embedding or reward model
            assert batch.extend_num_tokens != 0
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            return embeddings

    def process_batch_result(self, batch: ScheduleBatch, result):
        if batch.forward_mode.is_decode():
            self.process_batch_result_decode(batch, result)
        else:
            self.process_batch_result_prefill(batch, result)

    def process_batch_result_prefill(self, batch: ScheduleBatch, result):
        if self.is_generation:
            logits_output, next_token_ids = result
            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                next_token_ids
            )

            if logits_output:
                # Move logprobs to cpu
                if logits_output.next_token_logprobs is not None:
                    logits_output.next_token_logprobs = (
                        logits_output.next_token_logprobs[
                            torch.arange(
                                len(next_token_ids), device=next_token_ids.device
                            ),
                            next_token_ids,
                        ].tolist()
                    )
                    logits_output.input_token_logprobs = (
                        logits_output.input_token_logprobs.tolist()
                    )
                    logits_output.normalized_prompt_logprobs = (
                        logits_output.normalized_prompt_logprobs.tolist()
                    )

            next_token_ids = next_token_ids.tolist()

            # Check finish conditions
            logprob_pt = 0
            for i, req in enumerate(batch.reqs):
                if req is not self.current_inflight_req:
                    # Inflight reqs' prefill is not finished
                    req.completion_tokens_wo_jump_forward += 1
                    req.output_ids.append(next_token_ids[i])
                    req.check_finished()

                if req.regex_fsm is not None:
                    req.regex_fsm_state = req.regex_fsm.get_next_state(
                        req.regex_fsm_state, next_token_ids[i]
                    )

                if req.finished():
                    self.tree_cache.cache_finished_req(req)
                elif req not in batch.decoding_reqs:
                    # To reduce overhead, only cache prefill reqs
                    self.tree_cache.cache_unfinished_req(req)

                if req is self.current_inflight_req:
                    # Inflight request would get a new req idx
                    self.req_to_token_pool.free(req.req_pool_idx)

                if req.return_logprob:
                    logprob_pt += self.add_logprob_return_values(
                        i, req, logprob_pt, next_token_ids, logits_output
                    )
        else:  # embedding or reward model
            assert batch.extend_num_tokens != 0
            embeddings = result

            # Check finish conditions
            for i, req in enumerate(batch.reqs):
                req.embedding = embeddings[i]
                if req is not self.current_inflight_req:
                    # Inflight reqs' prefill is not finished
                    # dummy output token for embedding models
                    req.output_ids.append(0)
                    req.check_finished()

                if req.finished():
                    self.tree_cache.cache_finished_req(req)
                else:
                    self.tree_cache.cache_unfinished_req(req)

                if req is self.current_inflight_req:
                    # Inflight request would get a new req idx
                    self.req_to_token_pool.free(req.req_pool_idx)

        self.handle_finished_requests(batch)

        if not batch.is_empty():
            if self.running_batch is None:
                self.running_batch = batch
            else:
                self.running_batch.merge_batch(batch)

    def process_batch_result_decode(self, batch: ScheduleBatch, result):
        logits_output, next_token_ids = result
        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
            next_token_ids
        )
        self.num_generated_tokens += len(batch.reqs)

        # Move logprobs to cpu
        if logits_output.next_token_logprobs is not None:
            next_token_logprobs = logits_output.next_token_logprobs[
                torch.arange(len(next_token_ids), device=next_token_ids.device),
                next_token_ids,
            ].tolist()

        next_token_ids = next_token_ids.tolist()

        # Check finish condition
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            req.completion_tokens_wo_jump_forward += 1
            req.output_ids.append(next_token_id)
            req.check_finished()

            if req.regex_fsm is not None:
                req.regex_fsm_state = req.regex_fsm.get_next_state(
                    req.regex_fsm_state, next_token_id
                )

            if req.finished():
                self.tree_cache.cache_finished_req(req)

            if req.return_logprob:
                req.output_token_logprobs.append(
                    (next_token_logprobs[i], next_token_id)
                )
                if req.top_logprobs_num > 0:
                    req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

        self.handle_finished_requests(batch)

        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
        if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
            self.print_decode_stats()

        if self.running_batch.is_empty():
            self.running_batch = None

    def add_logprob_return_values(
        self,
        i: int,
        req: Req,
        pt: int,
        next_token_ids: List[int],
        output: LogitsProcessorOutput,
    ):
        """Attach logprobs to the return values."""
        req.output_token_logprobs.append(
            (output.next_token_logprobs[i], next_token_ids[i])
        )

        # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
        num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len

        if req.normalized_prompt_logprob is None:
            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

        if req.input_token_logprobs is None:
            input_token_logprobs = output.input_token_logprobs[
                pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
            ]
            input_token_ids = req.fill_ids[
                len(req.fill_ids)
                - num_input_logprobs
                + 1 : len(req.fill_ids)
                - req.last_update_decode_tokens
            ]
            req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))

            if (
                req.logprob_start_len == 0
            ):  # The first token does not have logprob, pad it.
                req.input_token_logprobs = [
                    (None, req.fill_ids[0])
                ] + req.input_token_logprobs

        if req.last_update_decode_tokens != 0:
            # Some decode tokens are re-computed in an extend batch
            req.output_token_logprobs.extend(
                list(
                    zip(
                        output.input_token_logprobs[
                            pt
                            + num_input_logprobs
                            - 1
                            - req.last_update_decode_tokens : pt
                            + num_input_logprobs
                            - 1
                        ],
                        req.fill_ids[
                            len(req.fill_ids)
                            - req.last_update_decode_tokens : len(req.fill_ids)
                        ],
                    )
                )
            )

        if req.top_logprobs_num > 0:
            if req.input_top_logprobs is None:
                req.input_top_logprobs = output.input_top_logprobs[i]
                if req.logprob_start_len == 0:
                    req.input_top_logprobs = [None] + req.input_top_logprobs

            if req.last_update_decode_tokens != 0:
                req.output_top_logprobs.extend(
                    output.input_top_logprobs[i][-req.last_update_decode_tokens :]
                )
            req.output_top_logprobs.append(output.output_top_logprobs[i])

        return num_input_logprobs

    def handle_finished_requests(self, batch: ScheduleBatch):
        output_rids = []
        output_meta_info = []
        output_finished_reason: List[BaseFinishReason] = []
        if self.is_generation:
            output_vids = []
            decoded_texts = []
            output_read_ids = []
            output_read_offsets = []
            output_skip_special_tokens = []
            output_spaces_between_special_tokens = []
        else:  # embedding or reward model
            output_embeddings = []
        unfinished_indices = []

        for i, req in enumerate(batch.reqs):
            if not req.finished() and req is not self.current_inflight_req:
                unfinished_indices.append(i)
            else:
                self.batch_is_full = False

            if req.finished() or (
                req.stream
                and (
                    self.decode_forward_ct % self.stream_interval == 0
                    or len(req.output_ids) == 1
                )
            ):
                output_rids.append(req.rid)
                output_finished_reason.append(req.finished_reason)
                if self.is_generation:
                    output_vids.append(req.vid)
                    decoded_texts.append(req.decoded_text)
                    read_ids, read_offset = req.init_incremental_detokenize()
                    output_read_ids.append(read_ids)
                    output_read_offsets.append(read_offset)
                    output_skip_special_tokens.append(
                        req.sampling_params.skip_special_tokens
                    )
                    output_spaces_between_special_tokens.append(
                        req.sampling_params.spaces_between_special_tokens
                    )

                    meta_info = {
                        "prompt_tokens": len(req.origin_input_ids),
                        "completion_tokens": len(req.output_ids),
                        "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                        "finish_reason": (
                            req.finished_reason.to_json()
                            if req.finished_reason is not None
                            else None
                        ),
                    }
                    if req.return_logprob:
                        (
                            meta_info["input_token_logprobs"],
                            meta_info["output_token_logprobs"],
                            meta_info["input_top_logprobs"],
                            meta_info["output_top_logprobs"],
                            meta_info["normalized_prompt_logprob"],
                        ) = (
                            req.input_token_logprobs,
                            req.output_token_logprobs,
                            req.input_top_logprobs,
                            req.output_top_logprobs,
                            req.normalized_prompt_logprob,
                        )
                    output_meta_info.append(meta_info)
                else:  # embedding or reward model
                    output_embeddings.append(req.embedding)
                    meta_info = {
                        "prompt_tokens": len(req.origin_input_ids),
                    }
                    output_meta_info.append(meta_info)

        # Send to detokenizer
        if output_rids:
            if self.is_generation:
                self.out_pyobjs.append(
                    BatchTokenIDOut(
                        output_rids,
                        output_vids,
                        decoded_texts,
                        output_read_ids,
                        output_read_offsets,
                        output_skip_special_tokens,
                        output_spaces_between_special_tokens,
                        output_meta_info,
                        output_finished_reason,
                    )
                )
            else:  # embedding or reward model
                self.out_pyobjs.append(
                    BatchEmbeddingOut(
                        output_rids,
                        output_embeddings,
                        output_meta_info,
                        output_finished_reason,
                    )
                )

        # Remove finished reqs: update batch tensors
        batch.filter_batch(unfinished_indices)

    def flush_cache(self):
        if len(self.waiting_queue) == 0 and (
            self.running_batch is None or len(self.running_batch.reqs) == 0
        ):
            self.tree_cache.reset()
            self.tree_cache_metrics = {"total": 0, "hit": 0}
            self.regex_fsm_cache.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool.clear()
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
            if_success = True
        else:
            logging.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )
            if_success = False
        return if_success

    def abort_request(self, recv_req: AbortReq):
        # Delete requests in the waiting queue
        to_del = None
        for i, req in enumerate(self.waiting_queue):
            if req.rid == recv_req.rid:
                to_del = i
                break

        if to_del is not None:
            del self.waiting_queue[to_del]

        # Delete requests in the running batch
        if self.running_batch:
            for req in self.running_batch.reqs:
                if req.rid == recv_req.rid:
                    req.finished_reason = FINISH_ABORT()
                    break

    def update_weights(self, recv_req: UpdateWeightReqInput):
        success, message = self.tp_worker.update_weights(recv_req)
        if success:
            flash_cache_success = self.flush_cache()
            assert flash_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return success, message


def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    pipe_writer,
):
    configure_logger(server_args, prefix=f" TP{tp_rank}")
    suppress_other_loggers()

    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
        pipe_writer.send("ready")
        scheduler.event_loop()
    except Exception:
        msg = get_exception_traceback()
        logger.error(msg)
        kill_parent_process()