sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff shows the content changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/__init__.py +59 -2
- sglang/api.py +40 -11
- sglang/backend/anthropic.py +17 -3
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +160 -12
- sglang/backend/runtime_endpoint.py +62 -27
- sglang/backend/vertexai.py +1 -0
- sglang/bench_latency.py +320 -0
- sglang/global_config.py +24 -3
- sglang/lang/chat_template.py +122 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +206 -98
- sglang/lang/ir.py +98 -34
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +32 -0
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +9 -2
- sglang/srt/constrained/jump_forward.py +113 -24
- sglang/srt/conversation.py +4 -2
- sglang/srt/flush_cache.py +18 -0
- sglang/srt/hf_transformers_utils.py +144 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +20 -1
- sglang/srt/layers/fused_moe.py +596 -0
- sglang/srt/layers/logits_processor.py +190 -61
- sglang/srt/layers/radix_attention.py +62 -53
- sglang/srt/layers/token_attention.py +21 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/controller/infer_batch.py +908 -0
- sglang/srt/managers/controller/manager_multi.py +195 -0
- sglang/srt/managers/controller/manager_single.py +177 -0
- sglang/srt/managers/controller/model_runner.py +359 -0
- sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
- sglang/srt/managers/controller/schedule_heuristic.py +65 -0
- sglang/srt/managers/controller/tp_worker.py +813 -0
- sglang/srt/managers/detokenizer_manager.py +42 -40
- sglang/srt/managers/io_struct.py +44 -10
- sglang/srt/managers/tokenizer_manager.py +224 -82
- sglang/srt/memory_pool.py +52 -59
- sglang/srt/model_config.py +97 -2
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +369 -0
- sglang/srt/models/dbrx.py +406 -0
- sglang/srt/models/gemma.py +34 -38
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +47 -37
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +92 -27
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/minicpm.py +366 -0
- sglang/srt/models/mixtral.py +302 -127
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +40 -35
- sglang/srt/models/qwen2.py +33 -36
- sglang/srt/models/qwen2_moe.py +473 -0
- sglang/srt/models/stablelm.py +33 -39
- sglang/srt/models/yivl.py +19 -26
- sglang/srt/openai_api_adapter.py +411 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +197 -481
- sglang/srt/server_args.py +190 -74
- sglang/srt/utils.py +460 -95
- sglang/test/test_programs.py +73 -10
- sglang/test/test_utils.py +226 -7
- sglang/utils.py +97 -27
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
- sglang-0.1.21.dist-info/RECORD +82 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/managers/router/infer_batch.py +0 -503
- sglang/srt/managers/router/manager.py +0 -79
- sglang/srt/managers/router/model_rpc.py +0 -686
- sglang/srt/managers/router/model_runner.py +0 -514
- sglang/srt/managers/router/scheduler.py +0 -70
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
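For orientation before the `tp_worker.py` hunk below, here is a minimal usage sketch (not part of the diff) showing how the 0.1.21 frontend is typically pointed at the SRT runtime whose scheduler internals this release reorganizes; the endpoint URL, served model, and prompt are placeholders.

```python
# Minimal sketch, assuming a local SRT server started with
# `python -m sglang.launch_server`; URL and prompt are placeholders.
import sglang as sgl

sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

@sgl.function
def qa(s, question):
    s += "Q: " + question + "\n"
    s += "A: " + sgl.gen("answer", max_tokens=64)

state = qa.run(question="What does a tensor parallel worker do?")
print(state["answer"])
```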
--- /dev/null
+++ b/sglang/srt/managers/controller/tp_worker.py
@@ -0,0 +1,813 @@
+"""A tensor parallel worker."""
+
+import asyncio
+import logging
+import time
+import warnings
+from concurrent.futures import ThreadPoolExecutor
+from typing import List, Optional
+
+import rpyc
+import torch
+from rpyc.utils.classic import obtain
+
+from sglang.global_config import global_config
+from sglang.srt.constrained.fsm_cache import FSMCache
+from sglang.srt.constrained.jump_forward import JumpForwardCache
+from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.managers.controller.infer_batch import (
+    FINISH_ABORT,
+    BaseFinishReason,
+    Batch,
+    ForwardMode,
+    Req,
+)
+from sglang.srt.managers.controller.model_runner import ModelRunner
+from sglang.srt.managers.controller.radix_cache import RadixCache
+from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
+from sglang.srt.managers.io_struct import (
+    AbortReq,
+    BatchTokenIDOut,
+    FlushCacheReq,
+    TokenizedGenerateReqInput,
+)
+from sglang.srt.model_config import ModelConfig
+from sglang.srt.server_args import ModelPortArgs, ServerArgs
+from sglang.srt.utils import (
+    connect_rpyc_service,
+    get_int_token_logit_bias,
+    is_multimodal_model,
+    set_random_seed,
+    start_rpyc_service_process,
+    suppress_other_loggers,
+)
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger("srt.tp_worker")
+
+
+class ModelTpServer:
+    def __init__(
+        self,
+        gpu_id: int,
+        tp_rank: int,
+        server_args: ServerArgs,
+        model_port_args: ModelPortArgs,
+        model_overide_args: dict,
+    ):
+        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
+        suppress_other_loggers()
+
+        # Copy arguments
+        self.gpu_id = gpu_id
+        self.tp_rank = tp_rank
+        self.tp_size = server_args.tp_size
+        self.dp_size = server_args.dp_size
+        self.schedule_heuristic = server_args.schedule_heuristic
+        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
+
+        # Init model and tokenizer
+        self.model_config = ModelConfig(
+            server_args.model_path,
+            server_args.trust_remote_code,
+            context_length=server_args.context_length,
+            model_overide_args=model_overide_args,
+        )
+        self.model_runner = ModelRunner(
+            model_config=self.model_config,
+            mem_fraction_static=server_args.mem_fraction_static,
+            gpu_id=gpu_id,
+            tp_rank=tp_rank,
+            tp_size=server_args.tp_size,
+            nccl_port=model_port_args.nccl_port,
+            server_args=server_args,
+        )
+
+        if is_multimodal_model(server_args.model_path):
+            self.processor = get_processor(
+                server_args.tokenizer_path,
+                tokenizer_mode=server_args.tokenizer_mode,
+                trust_remote_code=server_args.trust_remote_code,
+            )
+            self.tokenizer = self.processor.tokenizer
+        else:
+            self.tokenizer = get_tokenizer(
+                server_args.tokenizer_path,
+                tokenizer_mode=server_args.tokenizer_mode,
+                trust_remote_code=server_args.trust_remote_code,
+            )
+        self.max_total_num_tokens = self.model_runner.max_total_num_tokens
+        self.max_prefill_tokens = (
+            16384
+            if server_args.max_prefill_tokens is None
+            else server_args.max_prefill_tokens
+        )
+        self.max_running_requests = (
+            self.max_total_num_tokens // 2
+            if server_args.max_running_requests is None
+            else server_args.max_running_requests
+        )
+        self.int_token_logit_bias = torch.tensor(
+            get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
+        )
+        set_random_seed(server_args.random_seed)
+
+        # Print info
+        logger.info(
+            f"[gpu_id={self.gpu_id}] "
+            f"max_total_num_tokens={self.max_total_num_tokens}, "
+            f"max_prefill_tokens={self.max_prefill_tokens}, "
+            f"context_len={self.model_config.context_len}"
+        )
+        if self.tp_rank == 0:
+            logger.info(
+                f"[gpu_id={self.gpu_id}] "
+                f"server_args: {server_args.print_mode_args()}"
+            )
+
+        # Init cache
+        self.tree_cache = RadixCache(
+            req_to_token_pool=self.model_runner.req_to_token_pool,
+            token_to_kv_pool=self.model_runner.token_to_kv_pool,
+            disable=server_args.disable_radix_cache,
+        )
+        self.tree_cache_metrics = {"total": 0, "hit": 0}
+        self.scheduler = ScheduleHeuristic(
+            self.schedule_heuristic,
+            self.max_running_requests,
+            self.max_prefill_tokens,
+            self.max_total_num_tokens,
+            self.tree_cache,
+        )
+        self.req_to_token_pool = self.model_runner.req_to_token_pool
+        self.token_to_kv_pool = self.model_runner.token_to_kv_pool
+
+        # Init running status
+        self.forward_queue: List[Req] = []
+        self.running_batch: Batch = None
+        self.out_pyobjs = []
+        self.decode_forward_ct = 0
+        self.stream_interval = server_args.stream_interval
+        self.num_generated_tokens = 0
+        self.last_stats_tic = time.time()
+
+        # Init the FSM cache for constrained generation
+        self.regex_fsm_cache = FSMCache(
+            server_args.tokenizer_path,
+            {
+                "tokenizer_mode": server_args.tokenizer_mode,
+                "trust_remote_code": server_args.trust_remote_code,
+            },
+        )
+        self.jump_forward_cache = JumpForwardCache()
+
+        # Init new token estimation
+        assert (
+            server_args.schedule_conservativeness >= 0
+        ), "Invalid schedule_conservativeness"
+        self.new_token_ratio = min(
+            global_config.base_new_token_ratio * server_args.schedule_conservativeness,
+            1.0,
+        )
+        self.min_new_token_ratio = min(
+            global_config.base_min_new_token_ratio
+            * server_args.schedule_conservativeness,
+            1.0,
+        )
+        self.new_token_ratio_decay = global_config.new_token_ratio_decay
+        self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
+
+    def exposed_step(self, recv_reqs):
+        if not isinstance(recv_reqs, list):
+            recv_reqs = obtain(recv_reqs)
+
+        try:
+            # Recv requests
+            for recv_req in recv_reqs:
+                if isinstance(recv_req, TokenizedGenerateReqInput):
+                    self.handle_generate_request(recv_req)
+                elif isinstance(recv_req, FlushCacheReq):
+                    self.flush_cache()
+                elif isinstance(recv_req, AbortReq):
+                    self.abort_request(recv_req)
+                else:
+                    raise ValueError(f"Invalid request: {recv_req}")
+
+            # Forward
+            self.forward_step()
+        except Exception:
+            logger.error("Exception in ModelTpServer:\n" + get_exception_traceback())
+            raise
+
+        # Return results
+        ret = self.out_pyobjs
+        self.out_pyobjs = []
+        return ret
+
+    @torch.inference_mode()
+    def forward_step(self):
+        new_batch = self.get_new_prefill_batch()
+
+        if new_batch is not None:
+            # Run a new prefill batch
+            self.forward_prefill_batch(new_batch)
+            self.cache_filled_batch(new_batch)
+
+            if not new_batch.is_empty():
+                if self.running_batch is None:
+                    self.running_batch = new_batch
+                else:
+                    self.running_batch.merge(new_batch)
+        else:
+            # Run a decode batch
+            if self.running_batch is not None:
+                # Run a few decode batches continuously for reducing overhead
+                for _ in range(global_config.num_continue_decode_steps):
+                    self.num_generated_tokens += len(self.running_batch.reqs)
+                    self.forward_decode_batch(self.running_batch)
+
+                    # Print stats
+                    if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
+                        num_used = self.max_total_num_tokens - (
+                            self.token_to_kv_pool.available_size()
+                            + self.tree_cache.evictable_size()
+                        )
+                        throughput = self.num_generated_tokens / (
+                            time.time() - self.last_stats_tic
+                        )
+                        self.num_generated_tokens = 0
+                        self.last_stats_tic = time.time()
+                        logger.info(
+                            f"[gpu_id={self.gpu_id}] Decode batch. "
+                            f"#running-req: {len(self.running_batch.reqs)}, "
+                            f"#token: {num_used}, "
+                            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                            f"gen throughput (token/s): {throughput:.2f}, "
+                            f"#queue-req: {len(self.forward_queue)}"
+                        )
+
+                    if self.running_batch.is_empty():
+                        self.running_batch = None
+                        break
+
+                    if self.out_pyobjs and self.running_batch.has_stream():
+                        break
+            else:
+                # Check the available size
+                available_size = (
+                    self.token_to_kv_pool.available_size()
+                    + self.tree_cache.evictable_size()
+                )
+                if available_size != self.max_total_num_tokens:
+                    warnings.warn(
+                        "Warning: "
+                        f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
+                        "KV cache pool leak detected!"
+                    )
+
+    def handle_generate_request(
+        self,
+        recv_req: TokenizedGenerateReqInput,
+    ):
+        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
+        req.pixel_values = recv_req.pixel_values
+        if req.pixel_values is not None:
+            req.pad_value = [
+                (recv_req.image_hash) % self.model_config.vocab_size,
+                (recv_req.image_hash >> 16) % self.model_config.vocab_size,
+                (recv_req.image_hash >> 32) % self.model_config.vocab_size,
+                (recv_req.image_hash >> 64) % self.model_config.vocab_size,
+            ]
+            req.image_size = recv_req.image_size
+            (
+                req.origin_input_ids,
+                req.image_offset,
+            ) = self.model_runner.model.pad_input_ids(
+                req.origin_input_ids_unpadded,
+                req.pad_value,
+                req.pixel_values.shape,
+                req.image_size,
+            )
+        req.sampling_params = recv_req.sampling_params
+        req.return_logprob = recv_req.return_logprob
+        req.logprob_start_len = recv_req.logprob_start_len
+        req.top_logprobs_num = recv_req.top_logprobs_num
+        req.stream = recv_req.stream
+        req.tokenizer = self.tokenizer
+
+        # Init regex fsm
+        if req.sampling_params.regex is not None:
+            req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
+            if not self.disable_regex_jump_forward:
+                req.jump_forward_map = self.jump_forward_cache.query(
+                    req.sampling_params.regex
+                )
+
+        # Truncate prompts that are too long
+        req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
+        req.sampling_params.max_new_tokens = min(
+            req.sampling_params.max_new_tokens,
+            self.model_config.context_len - 1 - len(req.origin_input_ids),
+            self.max_total_num_tokens - 128 - len(req.origin_input_ids),
+        )
+        self.forward_queue.append(req)
+
+    def get_new_prefill_batch(self) -> Optional[Batch]:
+        running_bs = (
+            len(self.running_batch.reqs) if self.running_batch is not None else 0
+        )
+        if running_bs >= self.max_running_requests:
+            return
+
+        # Compute matched prefix length
+        for req in self.forward_queue:
+            req.input_ids = req.origin_input_ids + req.output_ids
+            prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
+            if req.return_logprob:
+                prefix_indices = prefix_indices[: req.logprob_start_len]
+            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
+            req.prefix_indices = prefix_indices
+            req.last_node = last_node
+
+        # Get priority queue
+        self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)
+
+        # Add requests if there is available space
+        can_run_list = []
+        new_batch_total_tokens = 0
+        new_batch_input_tokens = 0
+
+        available_size = (
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+        )
+        if self.running_batch:
+            available_size -= sum(
+                [
+                    (r.sampling_params.max_new_tokens - len(r.output_ids)) * self.new_token_ratio
+                    for r in self.running_batch.reqs
+                ]
+            )
+
+        for req in self.forward_queue:
+            if req.return_logprob and req.normalized_prompt_logprob is None:
+                # Need at least two tokens to compute normalized logprob
+                if req.extend_input_len < 2:
+                    delta = 2 - req.extend_input_len
+                    req.extend_input_len += delta
+                    req.prefix_indices = req.prefix_indices[:-delta]
+                    if req.image_offset is not None:
+                        req.image_offset += delta
+            if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
+                # Need at least one token to compute logits
+                req.extend_input_len = 1
+                req.prefix_indices = req.prefix_indices[:-1]
+                if req.image_offset is not None:
+                    req.image_offset += 1
+
+            if (
+                req.extend_input_len + req.sampling_params.max_new_tokens + new_batch_total_tokens
+                < available_size
+                and (
+                    req.extend_input_len + new_batch_input_tokens
+                    <= self.max_prefill_tokens
+                    or len(can_run_list) == 0
+                )
+            ):
+                delta = self.tree_cache.inc_lock_ref(req.last_node)
+                available_size += delta
+
+                if not (
+                    req.extend_input_len + req.sampling_params.max_new_tokens + new_batch_total_tokens
+                    < available_size
+                ):
+                    # Undo locking
+                    delta = self.tree_cache.dec_lock_ref(req.last_node)
+                    available_size += delta
+                    break
+                else:
+                    # Add this request to the running batch
+                    can_run_list.append(req)
+                    new_batch_total_tokens += (
+                        req.extend_input_len + req.sampling_params.max_new_tokens
+                    )
+                    new_batch_input_tokens += req.extend_input_len
+            else:
+                break
+
+            if running_bs + len(can_run_list) >= self.max_running_requests:
+                break
+
+        if len(can_run_list) == 0:
+            return None
+
+        # Print stats
+        if self.tp_rank == 0:
+            hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
+            self.tree_cache_metrics["total"] += (
+                hit_tokens + new_batch_input_tokens
+            ) / 10**9
+            self.tree_cache_metrics["hit"] += hit_tokens / 10**9
+            tree_cache_hit_rate = (
+                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
+            )
+            logger.info(
+                f"[gpu_id={self.gpu_id}] Prefill batch. "
+                f"#new-seq: {len(can_run_list)}, "
+                f"#new-token: {new_batch_input_tokens}, "
+                f"#cached-token: {hit_tokens}, "
+                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"#running-req: {running_bs}, "
+                f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
+            )
+            # logger.debug(
+            #     f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
+            #     f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
+            #     f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
+            #     f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
+            # )
+
+        # Return the new batch
+        new_batch = Batch.init_new(
+            can_run_list,
+            self.req_to_token_pool,
+            self.token_to_kv_pool,
+            self.tree_cache,
+        )
+        self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
+        return new_batch
+
+    def forward_prefill_batch(self, batch: Batch):
+        # Build batch tensors
+        batch.prepare_for_extend(
+            self.model_config.vocab_size, self.int_token_logit_bias
+        )
+
+        # Forward and sample the next tokens
+        if batch.extend_num_tokens != 0:
+            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            next_token_ids, _ = batch.sample(output.next_token_logits)
+
+            # Move logprobs to cpu
+            if output.next_token_logprobs is not None:
+                output.next_token_logprobs = output.next_token_logprobs[
+                    torch.arange(len(next_token_ids), device=next_token_ids.device),
+                    next_token_ids,
+                ].tolist()
+                output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
+                output.normalized_prompt_logprobs = (
+                    output.normalized_prompt_logprobs.tolist()
+                )
+
+            next_token_ids = next_token_ids.tolist()
+        else:
+            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
+
+        # Check finish conditions
+        pt = 0
+        for i, req in enumerate(batch.reqs):
+            req.completion_tokens_wo_jump_forward += 1
+            req.output_ids.append(next_token_ids[i])
+            req.check_finished()
+
+            if req.return_logprob:
+                self.add_logprob_return_values(i, req, pt, next_token_ids, output)
+            pt += req.extend_input_len
+
+        self.handle_finished_requests(batch)
+
+    def add_logprob_return_values(self, i, req, pt, next_token_ids, output):
+        if req.normalized_prompt_logprob is None:
+            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
+
+        if req.prefill_token_logprobs is None:
+            # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+            req.prefill_token_logprobs = list(
+                zip(
+                    output.prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                    req.input_ids[-req.extend_input_len + 1 :],
+                )
+            )
+            if req.logprob_start_len == 0:
+                req.prefill_token_logprobs = [
+                    (None, req.input_ids[0])
+                ] + req.prefill_token_logprobs
+
+        if req.last_update_decode_tokens != 0:
+            req.decode_token_logprobs.extend(
+                list(
+                    zip(
+                        output.prefill_token_logprobs[
+                            pt
+                            + req.extend_input_len
+                            - req.last_update_decode_tokens : pt
+                            + req.extend_input_len
+                            - 1
+                        ],
+                        req.input_ids[-req.last_update_decode_tokens + 1 :],
+                    )
+                )
+            )
+
+        req.decode_token_logprobs.append(
+            (output.next_token_logprobs[i], next_token_ids[i])
+        )
+
+        if req.top_logprobs_num > 0:
+            if req.prefill_top_logprobs is None:
+                req.prefill_top_logprobs = output.prefill_top_logprobs[i]
+                if req.logprob_start_len == 0:
+                    req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+
+            if req.last_update_decode_tokens != 0:
+                req.decode_top_logprobs.extend(
+                    output.prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+                )
+            req.decode_top_logprobs.append(output.decode_top_logprobs[i])
+
+    def cache_filled_batch(self, batch: Batch):
+        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
+        for i, req in enumerate(batch.reqs):
+            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
+                token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
+                last_uncached_pos=len(req.prefix_indices),
+                req_pool_idx=req_pool_indices_cpu[i],
+                del_in_memory_pool=False,
+                old_last_node=req.last_node,
+            )
+            req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
+
+    def forward_decode_batch(self, batch: Batch):
+        # Check if decode out of memory
+        if not batch.check_decode_mem():
+            old_ratio = self.new_token_ratio
+            self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
+
+            retracted_reqs = batch.retract_decode()
+            logger.info(
+                "decode out of memory happened, "
+                f"#retracted_reqs: {len(retracted_reqs)}, "
+                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
+            )
+            self.forward_queue.extend(retracted_reqs)
+        else:
+            self.new_token_ratio = max(
+                self.new_token_ratio - self.new_token_ratio_decay,
+                self.min_new_token_ratio,
+            )
+
+        if not self.disable_regex_jump_forward:
+            # Check for jump-forward
+            jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
+            self.forward_queue.extend(jump_forward_reqs)
+            if batch.is_empty():
+                return
+
+        # Update batch tensors
+        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
+        batch.prepare_for_decode()
+
+        # Forward and sample the next tokens
+        output = self.model_runner.forward(batch, ForwardMode.DECODE)
+        next_token_ids, _ = batch.sample(output.next_token_logits)
+
+        # Move logprobs to cpu
+        if output.next_token_logprobs is not None:
+            next_token_logprobs = output.next_token_logprobs[
+                torch.arange(len(next_token_ids), device=next_token_ids.device),
+                next_token_ids,
+            ].tolist()
+
+        next_token_ids = next_token_ids.tolist()
+
+        # Check finish condition
+        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
+            req.completion_tokens_wo_jump_forward += 1
+            req.output_ids.append(next_token_id)
+            req.check_finished()
+
+            if req.return_logprob:
+                req.decode_token_logprobs.append(
+                    (next_token_logprobs[i], next_token_id)
+                )
+            if req.top_logprobs_num > 0:
+                req.decode_top_logprobs.append(output.decode_top_logprobs[i])
+
+        self.handle_finished_requests(batch)
+
+    def handle_finished_requests(self, batch: Batch):
+        output_rids = []
+        decoded_texts = []
+        surr_output_ids = []
+        read_output_ids = []
+        output_skip_special_tokens = []
+        output_spaces_between_special_tokens = []
+        output_meta_info = []
+        output_finished_reason: List[BaseFinishReason] = []
+        finished_indices = []
+        unfinished_indices = []
+        for i, req in enumerate(batch.reqs):
+            if req.finished():
+                finished_indices.append(i)
+            else:
+                unfinished_indices.append(i)
+
+            if req.finished() or (
+                (
+                    req.stream
+                    and (
+                        self.decode_forward_ct % self.stream_interval == 0
+                        or len(req.output_ids) == 1
+                    )
+                )
+            ):
+                output_rids.append(req.rid)
+                decoded_texts.append(req.decoded_text)
+                surr_ids, read_ids, _ = req.init_detokenize_incrementally()
+                surr_output_ids.append(surr_ids)
+                read_output_ids.append(read_ids)
+                output_skip_special_tokens.append(
+                    req.sampling_params.skip_special_tokens
+                )
+                output_spaces_between_special_tokens.append(
+                    req.sampling_params.spaces_between_special_tokens
+                )
+
+                meta_info = {
+                    "prompt_tokens": len(req.origin_input_ids),
+                    "completion_tokens": len(req.output_ids),
+                    "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
+                    "finish_reason": str(req.finished_reason),
+                }
+                if req.return_logprob:
+                    (
+                        meta_info["prefill_token_logprobs"],
+                        meta_info["decode_token_logprobs"],
+                        meta_info["prefill_top_logprobs"],
+                        meta_info["decode_top_logprobs"],
+                        meta_info["normalized_prompt_logprob"],
+                    ) = (
+                        req.prefill_token_logprobs,
+                        req.decode_token_logprobs,
+                        req.prefill_top_logprobs,
+                        req.decode_top_logprobs,
+                        req.normalized_prompt_logprob,
+                    )
+                output_meta_info.append(meta_info)
+                output_finished_reason.append(req.finished_reason)
+
+        # Send to detokenizer
+        if output_rids:
+            self.out_pyobjs.append(
+                BatchTokenIDOut(
+                    output_rids,
+                    decoded_texts,
+                    surr_output_ids,
+                    read_output_ids,
+                    output_skip_special_tokens,
+                    output_spaces_between_special_tokens,
+                    output_meta_info,
+                    output_finished_reason,
+                )
+            )
+
+        # Remove finished reqs
+        if finished_indices:
+            # Update radix cache
+            req_pool_indices_cpu = batch.req_pool_indices.tolist()
+            for i in finished_indices:
+                req = batch.reqs[i]
+                self.tree_cache.cache_req(
+                    token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
+                    last_uncached_pos=len(req.prefix_indices),
+                    req_pool_idx=req_pool_indices_cpu[i],
+                )
+
+                self.tree_cache.dec_lock_ref(req.last_node)
+
+            # Update batch tensors
+            if unfinished_indices:
+                batch.filter_batch(unfinished_indices)
+            else:
+                batch.reqs = []
+
+    def flush_cache(self):
+        if len(self.forward_queue) == 0 and (
+            self.running_batch is None or len(self.running_batch.reqs) == 0
+        ):
+            self.tree_cache.reset()
+            self.tree_cache_metrics = {"total": 0, "hit": 0}
+            self.regex_fsm_cache.reset()
+            self.req_to_token_pool.clear()
+            self.token_to_kv_pool.clear()
+            torch.cuda.empty_cache()
+            logger.info("Cache flushed successfully!")
+        else:
+            warnings.warn(
+                f"Cache not flushed because there are pending requests. "
+                f"#queue-req: {len(self.forward_queue)}, "
+                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+            )
+
+    def abort_request(self, recv_req):
+        # Delete requests in the waiting queue
+        to_del = None
+        for i, req in enumerate(self.forward_queue):
+            if req.rid == recv_req.rid:
+                to_del = i
+                break
+
+        if to_del is not None:
+            del self.forward_queue[to_del]
+
+        # Delete requests in the running batch
+        if self.running_batch:
+            for req in self.running_batch.reqs:
+                if req.rid == recv_req.rid:
+                    req.finished_reason = FINISH_ABORT()
+                    break
+
+
+class ModelTpService(rpyc.Service):
+    exposed_ModelTpServer = ModelTpServer
+
+
+class ModelTpClient:
+    def __init__(
+        self,
+        gpu_ids: List[int],
+        server_args: ServerArgs,
+        model_port_args: ModelPortArgs,
+        model_overide_args,
+    ):
+        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
+        self.tp_size = server_args.tp_size
+
+        if self.tp_size * server_args.dp_size == 1:
+            # Init model
+            assert len(gpu_ids) == 1
+            self.model_server = ModelTpService().exposed_ModelTpServer(
+                gpu_ids[0],
+                0,
+                server_args,
+                model_port_args,
+                model_overide_args,
+            )
+
+            # Wrap functions
+            def async_wrap(f):
+                async def _func(*args, **kwargs):
+                    return f(*args, **kwargs)
+
+                return _func
+
+            self.step = async_wrap(self.model_server.exposed_step)
+        else:
+            with ThreadPoolExecutor(self.tp_size) as executor:
+                # Launch model processes
+                if server_args.nnodes == 1:
+                    self.procs = list(
+                        executor.map(
+                            lambda args: start_rpyc_service_process(*args),
+                            [
+                                (ModelTpService, p)
+                                for p in model_port_args.model_tp_ports
+                            ],
+                        )
+                    )
+                    addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
+                else:
+                    addrs = [
+                        (ip, port)
+                        for ip, port in zip(
+                            model_port_args.model_tp_ips, model_port_args.model_tp_ports
+                        )
+                    ]
+
+                self.model_services = list(
+                    executor.map(lambda args: connect_rpyc_service(*args), addrs)
+                )
+
+                # Init model
+                def init_model(i):
+                    return self.model_services[i].ModelTpServer(
+                        gpu_ids[i],
+                        i,
+                        server_args,
+                        model_port_args,
+                        model_overide_args,
+                    )
+
+                self.model_servers = list(executor.map(init_model, range(self.tp_size)))
+
+                # Wrap functions
+                def async_wrap(func_name):
+                    fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]
+
+                    async def _func(*args, **kwargs):
+                        tasks = [f(*args, **kwargs) for f in fs]
+                        await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
+                        return obtain(tasks[0].value)
+
+                    return _func
+
+                self.step = async_wrap("step")