sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- sglang/__init__.py +59 -2
- sglang/api.py +40 -11
- sglang/backend/anthropic.py +17 -3
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +160 -12
- sglang/backend/runtime_endpoint.py +62 -27
- sglang/backend/vertexai.py +1 -0
- sglang/bench_latency.py +320 -0
- sglang/global_config.py +24 -3
- sglang/lang/chat_template.py +122 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +206 -98
- sglang/lang/ir.py +98 -34
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +32 -0
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +9 -2
- sglang/srt/constrained/jump_forward.py +113 -24
- sglang/srt/conversation.py +4 -2
- sglang/srt/flush_cache.py +18 -0
- sglang/srt/hf_transformers_utils.py +144 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +20 -1
- sglang/srt/layers/fused_moe.py +596 -0
- sglang/srt/layers/logits_processor.py +190 -61
- sglang/srt/layers/radix_attention.py +62 -53
- sglang/srt/layers/token_attention.py +21 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/controller/infer_batch.py +908 -0
- sglang/srt/managers/controller/manager_multi.py +195 -0
- sglang/srt/managers/controller/manager_single.py +177 -0
- sglang/srt/managers/controller/model_runner.py +359 -0
- sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
- sglang/srt/managers/controller/schedule_heuristic.py +65 -0
- sglang/srt/managers/controller/tp_worker.py +813 -0
- sglang/srt/managers/detokenizer_manager.py +42 -40
- sglang/srt/managers/io_struct.py +44 -10
- sglang/srt/managers/tokenizer_manager.py +224 -82
- sglang/srt/memory_pool.py +52 -59
- sglang/srt/model_config.py +97 -2
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +369 -0
- sglang/srt/models/dbrx.py +406 -0
- sglang/srt/models/gemma.py +34 -38
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +47 -37
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +92 -27
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/minicpm.py +366 -0
- sglang/srt/models/mixtral.py +302 -127
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +40 -35
- sglang/srt/models/qwen2.py +33 -36
- sglang/srt/models/qwen2_moe.py +473 -0
- sglang/srt/models/stablelm.py +33 -39
- sglang/srt/models/yivl.py +19 -26
- sglang/srt/openai_api_adapter.py +411 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +197 -481
- sglang/srt/server_args.py +190 -74
- sglang/srt/utils.py +460 -95
- sglang/test/test_programs.py +73 -10
- sglang/test/test_utils.py +226 -7
- sglang/utils.py +97 -27
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
- sglang-0.1.21.dist-info/RECORD +82 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/managers/router/infer_batch.py +0 -503
- sglang/srt/managers/router/manager.py +0 -79
- sglang/srt/managers/router/model_rpc.py +0 -686
- sglang/srt/managers/router/model_runner.py +0 -514
- sglang/srt/managers/router/scheduler.py +0 -70
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/managers/router/model_rpc.py
@@ -1,686 +0,0 @@
-import asyncio
-import logging
-import multiprocessing
-import time
-import warnings
-from concurrent.futures import ThreadPoolExecutor
-from typing import List
-
-import numpy as np
-import rpyc
-import torch
-from rpyc.utils.classic import obtain
-from rpyc.utils.server import ThreadedServer
-from sglang.srt.constrained.fsm_cache import FSMCache
-from sglang.srt.constrained.jump_forward import JumpForwardCache
-from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.io_struct import (
-    BatchTokenIDOut,
-    FlushCacheReq,
-    TokenizedGenerateReqInput,
-)
-from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
-from sglang.srt.managers.router.model_runner import ModelRunner
-from sglang.srt.managers.router.radix_cache import RadixCache
-from sglang.srt.managers.router.scheduler import Scheduler
-from sglang.srt.model_config import ModelConfig
-from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import (
-    get_exception_traceback,
-    get_int_token_logit_bias,
-    is_multimodal_model,
-    set_random_seed,
-)
-from vllm.logger import _default_handler as vllm_default_handler
-
-logger = logging.getLogger("model_rpc")
-
-
-class ModelRpcServer(rpyc.Service):
-    def exposed_init_model(
-        self,
-        tp_rank: int,
-        server_args: ServerArgs,
-        port_args: PortArgs,
-    ):
-        server_args, port_args = [obtain(x) for x in [server_args, port_args]]
-
-        # Copy arguments
-        self.tp_rank = tp_rank
-        self.tp_size = server_args.tp_size
-        self.schedule_heuristic = server_args.schedule_heuristic
-        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
-        vllm_default_handler.setLevel(
-            level=getattr(logging, server_args.log_level.upper())
-        )
-
-        # Init model and tokenizer
-        self.model_config = ModelConfig(
-            server_args.model_path,
-            server_args.trust_remote_code,
-            context_length=server_args.context_length,
-        )
-
-        # for model end global settings
-        server_args_dict = {
-            "enable_flashinfer": server_args.enable_flashinfer,
-            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-        }
-
-        self.model_runner = ModelRunner(
-            model_config=self.model_config,
-            mem_fraction_static=server_args.mem_fraction_static,
-            tp_rank=tp_rank,
-            tp_size=server_args.tp_size,
-            nccl_port=port_args.nccl_port,
-            load_format=server_args.load_format,
-            trust_remote_code=server_args.trust_remote_code,
-            server_args_dict=server_args_dict,
-        )
-        if is_multimodal_model(server_args.model_path):
-            self.processor = get_processor(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
-            self.tokenizer = self.processor.tokenizer
-        else:
-            self.tokenizer = get_tokenizer(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
-        self.eos_token_id = self.tokenizer.eos_token_id
-        self.max_total_num_token = self.model_runner.max_total_num_token
-        self.max_num_running_seq = self.max_total_num_token // 2
-        self.max_prefill_num_token = max(
-            self.model_config.context_len,
-            (
-                self.max_total_num_token // 6
-                if server_args.max_prefill_num_token is None
-                else server_args.max_prefill_num_token
-            ),
-        )
-        self.int_token_logit_bias = torch.tensor(
-            get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
-        )
-        set_random_seed(server_args.random_seed)
-        logger.info(
-            f"Rank {self.tp_rank}: "
-            f"max_total_num_token={self.max_total_num_token}, "
-            f"max_prefill_num_token={self.max_prefill_num_token}, "
-            f"context_len={self.model_config.context_len}, "
-        )
-        logger.info(server_args.get_optional_modes_logging())
-
-        # Init cache
-        self.tree_cache = RadixCache(server_args.disable_radix_cache)
-        self.tree_cache_metrics = {"total": 0, "hit": 0}
-        self.scheduler = Scheduler(
-            self.schedule_heuristic,
-            self.max_num_running_seq,
-            self.max_prefill_num_token,
-            self.max_total_num_token,
-            self.tree_cache,
-        )
-        self.req_to_token_pool = self.model_runner.req_to_token_pool
-        self.token_to_kv_pool = self.model_runner.token_to_kv_pool
-
-        # Init running status
-        self.forward_queue: List[Req] = []
-        self.running_batch: Batch = None
-        self.out_pyobjs = []
-        self.decode_forward_ct = 0
-        self.stream_interval = server_args.stream_interval
-
-        # Init the FSM cache for constrained generation
-        self.regex_fsm_cache = FSMCache(
-            server_args.tokenizer_path,
-            {
-                "tokenizer_mode": server_args.tokenizer_mode,
-                "trust_remote_code": server_args.trust_remote_code,
-            },
-        )
-        self.jump_forward_cache = JumpForwardCache()
-
-        # Init new token estimation
-        self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
-        self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
-        self.new_token_ratio_step = (0.0001, 0.05)  # (down, up)
-
-    def flush_cache(self):
-        if len(self.forward_queue) == 0 and (
-            self.running_batch is None or len(self.running_batch.reqs) == 0
-        ):
-            self.tree_cache.reset()
-            self.tree_cache_metrics = {"total": 0, "hit": 0}
-            self.regex_fsm_cache.reset()
-            self.req_to_token_pool.clear()
-            self.token_to_kv_pool.clear()
-            torch.cuda.empty_cache()
-            logger.info("Cache flushed successfully!")
-        else:
-            warnings.warn(
-                "Cache not flushed because there are pending requests. "
-                f"#queue-req: {len(self.forward_queue)}, "
-                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
-            )
-
-    def exposed_step(self, recv_reqs):
-        if self.tp_size != 1:
-            recv_reqs = obtain(recv_reqs)
-
-        try:
-            # Recv requests
-            for recv_req in recv_reqs:
-                if isinstance(recv_req, TokenizedGenerateReqInput):
-                    self.handle_generate_request(recv_req)
-                elif isinstance(recv_req, FlushCacheReq):
-                    self.flush_cache()
-                else:
-                    raise ValueError(f"Invalid request: {recv_req}")
-
-            # Forward
-            self.forward_step()
-        except Exception:
-            logger.error("Exception in ModelRpcClient:\n" + get_exception_traceback())
-
-        # Return results
-        ret = self.out_pyobjs
-        self.out_pyobjs = []
-        return ret
-
-    @torch.inference_mode()
-    def forward_step(self):
-        new_batch = self.get_new_fill_batch()
-
-        if new_batch is not None:
-            # Run new fill batch
-            self.forward_fill_batch(new_batch)
-
-            if not new_batch.is_empty():
-                if self.running_batch is None:
-                    self.running_batch = new_batch
-                else:
-                    self.running_batch.merge(new_batch)
-        else:
-            # Run decode batch
-            if self.running_batch is not None:
-                # Run a few decode batches continuously for reducing overhead
-                for _ in range(10):
-                    self.forward_decode_batch(self.running_batch)
-
-                    if self.running_batch.is_empty():
-                        self.running_batch = None
-                        break
-
-                    if self.out_pyobjs and self.running_batch.reqs[0].stream:
-                        break
-
-                if self.running_batch is not None and self.tp_rank == 0:
-                    if self.decode_forward_ct % 40 == 0:
-                        num_used = self.max_total_num_token - (
-                            self.token_to_kv_pool.available_size()
-                            + self.tree_cache.evictable_size()
-                        )
-                        logger.info(
-                            f"#running-req: {len(self.running_batch.reqs)}, "
-                            f"#token: {num_used}, "
-                            f"token usage: {num_used / self.max_total_num_token:.2f}, "
-                            f"#queue-req: {len(self.forward_queue)}"
-                        )
-            else:
-                # check the available size
-                available_size = (
-                    self.token_to_kv_pool.available_size()
-                    + self.tree_cache.evictable_size()
-                )
-                if available_size != self.max_total_num_token:
-                    warnings.warn(
-                        "Warning: "
-                        f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
-                        "KV cache pool leak detected!"
-                    )
-
-    def handle_generate_request(
-        self,
-        recv_req: TokenizedGenerateReqInput,
-    ):
-        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
-        req.pixel_values = recv_req.pixel_values
-        if req.pixel_values is not None:
-            req.pad_value = [
-                (recv_req.image_hash) % self.model_config.vocab_size,
-                (recv_req.image_hash >> 16) % self.model_config.vocab_size,
-                (recv_req.image_hash >> 32) % self.model_config.vocab_size,
-                (recv_req.image_hash >> 64) % self.model_config.vocab_size,
-            ]
-            req.image_size = recv_req.image_size
-            req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
-                req.input_ids, req.pad_value, req.pixel_values.shape, req.image_size
-            )
-        req.sampling_params = recv_req.sampling_params
-        req.return_logprob = recv_req.return_logprob
-        req.logprob_start_len = recv_req.logprob_start_len
-        req.stream = recv_req.stream
-        req.tokenizer = self.tokenizer
-
-        # Init regex fsm
-        if req.sampling_params.regex is not None:
-            req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
-            if not self.disable_regex_jump_forward:
-                req.jump_forward_map = self.jump_forward_cache.query(
-                    req.sampling_params.regex
-                )
-
-        # Truncate long prompts
-        req.input_ids = req.input_ids[: self.model_config.context_len - 1]
-        req.sampling_params.max_new_tokens = min(
-            req.sampling_params.max_new_tokens,
-            self.model_config.context_len - 1 - len(req.input_ids),
-            self.max_total_num_token - 128 - len(req.input_ids),
-        )
-        self.forward_queue.append(req)
-
-    def get_new_fill_batch(self):
-        if (
-            self.running_batch is not None
-            and len(self.running_batch.reqs) > self.max_num_running_seq
-        ):
-            return None
-
-        for req in self.forward_queue:
-            prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
-            if req.return_logprob:
-                prefix_indices = prefix_indices[: req.logprob_start_len]
-            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
-            req.prefix_indices = prefix_indices
-            req.last_node = last_node
-
-        # Get priority queue
-        self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)
-
-        # Add requests if there is available space
-        can_run_list = []
-        new_batch_total_tokens = 0
-        new_batch_input_tokens = 0
-
-        available_size = (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
-        )
-        if self.running_batch:
-            available_size -= sum(
-                [
-                    (r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
-                    for r in self.running_batch.reqs
-                ]
-            )
-
-        for req in self.forward_queue:
-            if req.return_logprob:
-                # Need at least two tokens to compute normalized logprob
-                if req.extend_input_len < 2:
-                    delta = 2 - req.extend_input_len
-                    req.extend_input_len += delta
-                    req.prefix_indices = req.prefix_indices[:-delta]
-                    if req.image_offset is not None:
-                        req.image_offset += delta
-            if req.extend_input_len == 0 and req.max_new_tokens() > 0:
-                # Need at least one token to compute logits
-                req.extend_input_len = 1
-                req.prefix_indices = req.prefix_indices[:-1]
-                if req.image_offset is not None:
-                    req.image_offset += 1
-
-            if (
-                req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
-                < available_size
-                and req.extend_input_len + new_batch_input_tokens
-                < self.max_prefill_num_token
-            ):
-                delta = self.tree_cache.inc_ref_counter(req.last_node)
-                available_size += delta
-
-                if not (
-                    req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
-                    < available_size
-                ):
-                    # Undo the insertion
-                    delta = self.tree_cache.dec_ref_counter(req.last_node)
-                    available_size += delta
-                else:
-                    # Add this request to the running batch
-                    self.token_to_kv_pool.add_refs(req.prefix_indices)
-                    can_run_list.append(req)
-                    new_batch_total_tokens += (
-                        req.extend_input_len + req.max_new_tokens()
-                    )
-                    new_batch_input_tokens += req.extend_input_len
-
-        if len(can_run_list) == 0:
-            return None
-
-        if self.tp_rank == 0:
-            running_req = (
-                0 if self.running_batch is None else len(self.running_batch.reqs)
-            )
-            hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
-            self.tree_cache_metrics["total"] += (
-                hit_tokens + new_batch_input_tokens
-            ) / 10**9
-            self.tree_cache_metrics["hit"] += hit_tokens / 10**9
-            tree_cache_hit_rate = (
-                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
-            )
-            logger.info(
-                f"new fill batch. #seq: {len(can_run_list)}. "
-                f"#cached_token: {hit_tokens}. "
-                f"#new_token: {new_batch_input_tokens}. "
-                f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
-                f"#running_req: {running_req}. "
-                f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
-            )
-            logger.debug(
-                f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
-                f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
-                f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
-                f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
-            )
-
-        new_batch = Batch.init_new(
-            can_run_list,
-            self.req_to_token_pool,
-            self.token_to_kv_pool,
-            self.tree_cache,
-        )
-        self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
-        return new_batch
-
-    def forward_fill_batch(self, batch: Batch):
-        # Build batch tensors
-        batch.prepare_for_extend(
-            self.model_config.vocab_size, self.int_token_logit_bias
-        )
-
-        logprobs = None
-        if batch.extend_num_tokens != 0:
-            # Forward
-            logits, (
-                prefill_logprobs,
-                normalized_logprobs,
-                last_logprobs,
-            ) = self.model_runner.forward(
-                batch, ForwardMode.EXTEND, batch.return_logprob
-            )
-            if prefill_logprobs is not None:
-                logprobs = prefill_logprobs.cpu().tolist()
-                normalized_logprobs = normalized_logprobs.cpu().tolist()
-
-            next_token_ids, _ = batch.sample(logits)
-            next_token_ids = next_token_ids.cpu().tolist()
-        else:
-            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
-            logits = logprobs = normalized_logprobs = last_logprobs = None
-
-        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
-        reqs = batch.reqs
-        if last_logprobs is not None:
-            last_logprobs = (
-                last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
-            )
-
-        # Check finish condition
-        pt = 0
-        for i, req in enumerate(reqs):
-            req.completion_tokens_wo_jump_forward += 1
-            req.output_ids = [next_token_ids[i]]
-            req.check_finished()
-
-            if logprobs is not None:
-                req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
-                req.normalized_logprob = normalized_logprobs[i]
-
-                # If logprob_start_len > 0, then first logprob_start_len prompt tokens
-                # will be ignored.
-                prompt_token_len = len(req.logprob)
-                token_ids = req.input_ids[-prompt_token_len:] + [next_token_ids[i]]
-                token_logprobs = req.logprob + [last_logprobs[i]]
-                req.token_logprob = list(zip(token_ids, token_logprobs))
-                if req.logprob_start_len == 0:
-                    req.token_logprob = [(req.input_ids[0], None)] + req.token_logprob
-                pt += req.extend_input_len
-
-        self.handle_finished_requests(batch)
-
-    def forward_decode_batch(self, batch: Batch):
-        # check if decode out of memory
-        if not batch.check_decode_mem():
-            old_ratio = self.new_token_ratio
-            self.new_token_ratio = min(old_ratio + self.new_token_ratio_step[1], 1.0)
-
-            retracted_reqs = batch.retract_decode()
-            logger.info(
-                "decode out of memory happened, "
-                f"#retracted_reqs: {len(retracted_reqs)}, "
-                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
-            )
-            self.forward_queue.extend(retracted_reqs)
-        else:
-            self.new_token_ratio = max(
-                self.new_token_ratio - self.new_token_ratio_step[0],
-                self.min_new_token_ratio,
-            )
-
-        if not self.disable_regex_jump_forward:
-            # check for jump-forward
-            jump_forward_reqs = batch.check_for_jump_forward()
-
-            # check for image jump-forward
-            for req in jump_forward_reqs:
-                if req.pixel_values is not None:
-                    (
-                        req.input_ids,
-                        req.image_offset,
-                    ) = self.model_runner.model.pad_input_ids(
-                        req.input_ids,
-                        req.pad_value,
-                        req.pixel_values.shape,
-                        req.image_size,
-                    )
-
-            self.forward_queue.extend(jump_forward_reqs)
-            if batch.is_empty():
-                return
-
-        # Update batch tensors
-        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
-        batch.prepare_for_decode()
-
-        # Forward
-        logits, (_, _, last_logprobs) = self.model_runner.forward(
-            batch,
-            ForwardMode.DECODE,
-            batch.return_logprob,
-        )
-        next_token_ids, _ = batch.sample(logits)
-        next_token_ids = next_token_ids.cpu().tolist()
-
-        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
-        reqs = batch.reqs
-        if last_logprobs is not None:
-            last_logprobs = last_logprobs[
-                torch.arange(len(reqs)), next_token_ids
-            ].tolist()
-
-        # Check finish condition
-        for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
-            req.completion_tokens_wo_jump_forward += 1
-            req.output_ids.append(next_tok_id)
-            req.check_finished()
-
-            if last_logprobs is not None:
-                req.token_logprob.append((next_tok_id, last_logprobs[i]))
-
-        self.handle_finished_requests(batch)
-
-    def handle_finished_requests(self, batch: Batch):
-        output_rids = []
-        output_tokens = []
-        output_and_jump_forward_strs = []
-        output_hit_stop_str = []
-        output_skip_special_tokens = []
-        output_meta_info = []
-        output_finished = []
-        finished_indices = []
-        unfinished_indices = []
-        for i, req in enumerate(batch.reqs):
-            if req.finished:
-                finished_indices.append(i)
-            else:
-                unfinished_indices.append(i)
-
-            if req.finished or (
-                (
-                    req.stream
-                    and (
-                        self.decode_forward_ct % self.stream_interval == 0
-                        or len(req.output_ids) == 1
-                    )
-                )
-            ):
-                output_rids.append(req.rid)
-                output_tokens.append(req.output_ids)
-                output_and_jump_forward_strs.append(req.output_and_jump_forward_str)
-                output_hit_stop_str.append(req.hit_stop_str)
-                output_skip_special_tokens.append(
-                    req.sampling_params.skip_special_tokens
-                )
-
-                meta_info = {
-                    "prompt_tokens": req.prompt_tokens,
-                    "completion_tokens": len(req.input_ids)
-                    + len(req.output_ids)
-                    - req.prompt_tokens,
-                    "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                }
-                if req.return_logprob:
-                    meta_info["prompt_logprob"] = req.logprob
-                    meta_info["token_logprob"] = req.token_logprob
-                    meta_info["normalized_prompt_logprob"] = req.normalized_logprob
-                output_meta_info.append(meta_info)
-                output_finished.append(req.finished)
-
-        # Send to detokenizer
-        if output_rids:
-            self.out_pyobjs.append(
-                BatchTokenIDOut(
-                    output_rids,
-                    output_tokens,
-                    output_and_jump_forward_strs,
-                    output_hit_stop_str,
-                    output_skip_special_tokens,
-                    output_meta_info,
-                    output_finished,
-                )
-            )
-
-        # Remove finished reqs
-        if finished_indices:
-            # Update radix cache
-            req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
-            for i in finished_indices:
-                req = batch.reqs[i]
-                req_pool_idx = req_pool_indices_cpu[i]
-                token_ids = tuple(req.input_ids + req.output_ids)
-                seq_len = len(token_ids) - 1
-                indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len]
-                prefix_len = self.tree_cache.insert(
-                    token_ids[:seq_len], indices.clone()
-                )
-
-                self.token_to_kv_pool.free(indices[:prefix_len])
-                self.req_to_token_pool.free(req_pool_idx)
-                self.tree_cache.dec_ref_counter(req.last_node)
-
-        # Update batch tensors
-        if unfinished_indices:
-            batch.filter_batch(unfinished_indices)
-        else:
-            batch.reqs = []
-
-
-class ModelRpcClient:
-    def __init__(self, server_args: ServerArgs, port_args: PortArgs):
-        tp_size = server_args.tp_size
-
-        if tp_size == 1:
-            # Init model
-            self.model_server = ModelRpcServer()
-            self.model_server.exposed_init_model(0, server_args, port_args)
-
-            # Wrap functions
-            def async_wrap(f):
-                async def _func(*args, **kwargs):
-                    return f(*args, **kwargs)
-
-                return _func
-
-            self.step = async_wrap(self.model_server.exposed_step)
-        else:
-            with ThreadPoolExecutor(tp_size) as executor:
-                # Launch model processes
-                rets = executor.map(start_model_process, port_args.model_rpc_ports)
-                self.model_servers = [x[0] for x in rets]
-                self.procs = [x[1] for x in rets]
-
-                # Init model
-                def init_model(i):
-                    return self.model_servers[i].init_model(i, server_args, port_args)
-
-                rets = [obtain(x) for x in executor.map(init_model, range(tp_size))]
-
-                # Wrap functions
-                def async_wrap(func_name):
-                    fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]
-
-                    async def _func(*args, **kwargs):
-                        tasks = [f(*args, **kwargs) for f in fs]
-                        await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
-                        return obtain(tasks[0].value)
-
-                    return _func
-
-                self.step = async_wrap("step")
-
-
-def _init_service(port):
-    t = ThreadedServer(
-        ModelRpcServer(),
-        port=port,
-        protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
-    )
-    t.start()
-
-
-def start_model_process(port):
-    proc = multiprocessing.Process(target=_init_service, args=(port,))
-    proc.start()
-    time.sleep(1)
-
-    repeat_count = 0
-    while repeat_count < 20:
-        try:
-            con = rpyc.connect(
-                "localhost",
-                port,
-                config={"allow_pickle": True, "sync_request_timeout": 1800},
-            )
-            break
-        except ConnectionRefusedError:
-            time.sleep(1)
-            repeat_count += 1
-    if repeat_count == 20:
-        raise RuntimeError("init rpc env error!")
-
-    assert proc.is_alive()
-    return con.root, proc