sglang 0.1.16__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +3 -1
- sglang/api.py +3 -3
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +148 -12
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/global_config.py +8 -1
- sglang/lang/interpreter.py +114 -67
- sglang/lang/ir.py +17 -2
- sglang/srt/constrained/fsm_cache.py +3 -0
- sglang/srt/flush_cache.py +1 -1
- sglang/srt/hf_transformers_utils.py +75 -1
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +485 -0
- sglang/srt/layers/logits_processor.py +12 -7
- sglang/srt/layers/radix_attention.py +10 -3
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +110 -0
- sglang/srt/managers/controller/infer_batch.py +619 -0
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/controller/manager_single.py +97 -0
- sglang/srt/managers/controller/model_runner.py +462 -0
- sglang/srt/managers/controller/radix_cache.py +267 -0
- sglang/srt/managers/controller/schedule_heuristic.py +59 -0
- sglang/srt/managers/controller/tp_worker.py +791 -0
- sglang/srt/managers/detokenizer_manager.py +45 -45
- sglang/srt/managers/io_struct.py +15 -11
- sglang/srt/managers/router/infer_batch.py +103 -59
- sglang/srt/managers/router/manager.py +1 -1
- sglang/srt/managers/router/model_rpc.py +175 -122
- sglang/srt/managers/router/model_runner.py +91 -104
- sglang/srt/managers/router/radix_cache.py +7 -1
- sglang/srt/managers/router/scheduler.py +6 -6
- sglang/srt/managers/tokenizer_manager.py +152 -89
- sglang/srt/model_config.py +4 -5
- sglang/srt/models/commandr.py +10 -13
- sglang/srt/models/dbrx.py +9 -15
- sglang/srt/models/gemma.py +8 -15
- sglang/srt/models/grok.py +671 -0
- sglang/srt/models/llama2.py +19 -15
- sglang/srt/models/llava.py +84 -20
- sglang/srt/models/llavavid.py +11 -20
- sglang/srt/models/mixtral.py +248 -118
- sglang/srt/models/mixtral_quant.py +373 -0
- sglang/srt/models/qwen.py +9 -13
- sglang/srt/models/qwen2.py +11 -13
- sglang/srt/models/stablelm.py +9 -15
- sglang/srt/models/yivl.py +17 -22
- sglang/srt/openai_api_adapter.py +140 -95
- sglang/srt/openai_protocol.py +10 -1
- sglang/srt/server.py +77 -42
- sglang/srt/server_args.py +51 -6
- sglang/srt/utils.py +124 -66
- sglang/test/test_programs.py +44 -0
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +22 -4
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/METADATA +15 -9
- sglang-0.1.17.dist-info/RECORD +81 -0
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -417
- sglang-0.1.16.dist-info/RECORD +0 -72
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,791 @@
|
|
1
|
+
import asyncio
|
2
|
+
import logging
|
3
|
+
import time
|
4
|
+
import warnings
|
5
|
+
from concurrent.futures import ThreadPoolExecutor
|
6
|
+
from typing import List
|
7
|
+
|
8
|
+
import rpyc
|
9
|
+
import torch
|
10
|
+
from rpyc.utils.classic import obtain
|
11
|
+
|
12
|
+
from sglang.global_config import global_config
|
13
|
+
from sglang.srt.constrained.fsm_cache import FSMCache
|
14
|
+
from sglang.srt.constrained.jump_forward import JumpForwardCache
|
15
|
+
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
16
|
+
from sglang.srt.managers.io_struct import (
|
17
|
+
AbortReq,
|
18
|
+
BatchTokenIDOut,
|
19
|
+
FlushCacheReq,
|
20
|
+
TokenizedGenerateReqInput,
|
21
|
+
)
|
22
|
+
from sglang.srt.managers.controller.infer_batch import BaseFinishReason, Batch, FINISH_ABORT, ForwardMode, Req
|
23
|
+
from sglang.srt.managers.controller.model_runner import ModelRunner
|
24
|
+
from sglang.srt.managers.controller.radix_cache import RadixCache
|
25
|
+
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
|
26
|
+
from sglang.srt.model_config import ModelConfig
|
27
|
+
from sglang.srt.server_args import ModelPortArgs, ServerArgs
|
28
|
+
from sglang.srt.utils import (
|
29
|
+
get_int_token_logit_bias,
|
30
|
+
is_multimodal_model,
|
31
|
+
set_random_seed,
|
32
|
+
start_rpyc_process,
|
33
|
+
suppress_other_loggers,
|
34
|
+
)
|
35
|
+
from sglang.utils import get_exception_traceback
|
36
|
+
|
37
|
+
# Module logger.  The name is a fixed string ("srt.tp_worker"), not __name__,
# so log configuration keyed on that name keeps working.
logger = logging.getLogger(name="srt.tp_worker")
|
38
|
+
|
39
|
+
|
40
|
+
class ModelTpServer:
    """One tensor-parallel worker: a model shard plus its scheduler and caches.

    The worker owns the ``ModelRunner`` (model and KV-cache pools), the
    tokenizer, the radix prefix cache, the queue of waiting requests and the
    running decode batch.  ``exposed_step`` is the entry point (the
    ``exposed_`` prefix makes it callable over rpyc): it ingests new
    requests, runs one scheduling/forward step and returns the
    ``BatchTokenIDOut`` messages produced for the detokenizer.
    """

    def __init__(
        self,
        gpu_id: int,
        tp_rank: int,
        server_args: ServerArgs,
        model_port_args: ModelPortArgs,
        model_overide_args,
    ):
        """Load the model shard and initialise tokenizer, caches and scheduler state.

        Args:
            gpu_id: Index of the GPU this worker runs on.
            tp_rank: Rank of this worker in the tensor-parallel group; only
                rank 0 prints the periodic statistics.
            server_args: Server configuration (possibly an rpyc proxy).
            model_port_args: Port configuration; ``nccl_port`` is used here
                (possibly an rpyc proxy).
            model_overide_args: Overrides forwarded verbatim to ``ModelConfig``.
        """
        # Over rpyc the arguments may arrive as remote proxies; ``obtain``
        # makes local copies.
        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
        suppress_other_loggers()

        # Copy arguments
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.dp_size = server_args.dp_size
        self.schedule_heuristic = server_args.schedule_heuristic
        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward

        # Init model and tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            server_args.trust_remote_code,
            context_length=server_args.context_length,
            model_overide_args=model_overide_args,
        )
        self.model_runner = ModelRunner(
            model_config=self.model_config,
            mem_fraction_static=server_args.mem_fraction_static,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            tp_size=server_args.tp_size,
            nccl_port=model_port_args.nccl_port,
            server_args=server_args,
        )

        if is_multimodal_model(server_args.model_path):
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
            self.tokenizer = self.processor.tokenizer
        else:
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
        self.max_total_num_tokens = self.model_runner.max_total_num_tokens
        # Never below the context length, so a prompt truncated to the
        # context window can pass the prefill-token limit on its own.
        self.max_prefill_tokens = max(
            self.model_config.context_len,
            (
                min(self.max_total_num_tokens // 6, 65536)
                if server_args.max_prefill_tokens is None
                else server_args.max_prefill_tokens
            ),
        )
        self.max_running_requests = (
            self.max_total_num_tokens // 2
            if server_args.max_running_requests is None
            else server_args.max_running_requests
        )
        self.int_token_logit_bias = torch.tensor(
            get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
        )
        set_random_seed(server_args.random_seed)

        # Print info
        logger.info(
            f"[gpu_id={self.gpu_id}] "
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"context_len={self.model_config.context_len}, "
        )
        if self.tp_rank == 0:
            logger.info(
                f"[gpu_id={self.gpu_id}] "
                f"server_args: {server_args.print_mode_args()}"
            )

        # Init cache
        self.tree_cache = RadixCache(
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool=self.model_runner.token_to_kv_pool,
            disable=server_args.disable_radix_cache,
        )
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.scheduler = ScheduleHeuristic(
            self.schedule_heuristic,
            self.max_running_requests,
            self.max_prefill_tokens,
            self.max_total_num_tokens,
            self.tree_cache,
        )
        self.req_to_token_pool = self.model_runner.req_to_token_pool
        self.token_to_kv_pool = self.model_runner.token_to_kv_pool

        # Init running status
        self.forward_queue: List[Req] = []
        self.running_batch: Batch = None
        self.out_pyobjs = []
        self.decode_forward_ct = 0
        self.stream_interval = server_args.stream_interval
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()

        # Init the FSM cache for constrained generation
        self.regex_fsm_cache = FSMCache(
            server_args.tokenizer_path,
            {
                "tokenizer_mode": server_args.tokenizer_mode,
                "trust_remote_code": server_args.trust_remote_code,
            },
        )
        self.jump_forward_cache = JumpForwardCache()

        # Init new token estimation.  new_token_ratio scales the KV space
        # reserved for the running requests' remaining max_new_tokens when
        # admitting new requests (see get_new_fill_batch).  It decays on
        # successful decode steps and recovers after a decode OOM.
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        self.new_token_ratio = min(
            global_config.base_new_token_ratio * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            global_config.base_min_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.new_token_ratio_decay = global_config.new_token_ratio_decay
        self.new_token_ratio_recovery = global_config.new_token_ratio_recovery

    def exposed_step(self, recv_reqs):
        """Ingest ``recv_reqs``, run one forward step and return pending outputs.

        Args:
            recv_reqs: Iterable of ``TokenizedGenerateReqInput``,
                ``FlushCacheReq`` or ``AbortReq`` objects (an rpyc proxy when
                running with more than one worker).

        Returns:
            The list of output objects accumulated since the previous call.

        Raises:
            ValueError: If a request of an unknown type is received.  Any
                exception is logged with its traceback and re-raised.
        """
        if self.tp_size * self.dp_size != 1:
            recv_reqs = obtain(recv_reqs)

        try:
            # Recv requests
            for recv_req in recv_reqs:
                if isinstance(recv_req, TokenizedGenerateReqInput):
                    self.handle_generate_request(recv_req)
                elif isinstance(recv_req, FlushCacheReq):
                    self.flush_cache()
                elif isinstance(recv_req, AbortReq):
                    self.abort_request(recv_req)
                else:
                    raise ValueError(f"Invalid request: {recv_req}")

            # Forward
            self.forward_step()
        except Exception:
            logger.error("Exception in ModelTpServer:\n" + get_exception_traceback())
            raise

        # Return results
        ret = self.out_pyobjs
        self.out_pyobjs = []
        return ret

    @torch.inference_mode()
    def forward_step(self):
        """Run a prefill batch if one can be formed; otherwise run decode steps.

        A new prefill ("fill") batch takes priority; its unfinished requests
        are then merged into the running batch.  Otherwise up to 10 decode
        iterations run back-to-back on the running batch to reduce overhead.
        When there is nothing to run at all, the KV pool accounting is
        checked and a warning is issued if tokens appear to have leaked.
        """
        new_batch = self.get_new_fill_batch()

        if new_batch is not None:
            # Run a new fill batch
            self.forward_fill_batch(new_batch)
            self.cache_filled_batch(new_batch)

            if not new_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = new_batch
                else:
                    self.running_batch.merge(new_batch)
        else:
            # Run decode batch
            if self.running_batch is not None:
                # Run a few decode batches continuously for reducing overhead
                for _ in range(10):
                    self.num_generated_tokens += len(self.running_batch.reqs)
                    self.forward_decode_batch(self.running_batch)

                    # Print stats
                    if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
                        num_used = self.max_total_num_tokens - (
                            self.token_to_kv_pool.available_size()
                            + self.tree_cache.evictable_size()
                        )
                        throughput = self.num_generated_tokens / (
                            time.time() - self.last_stats_tic
                        )
                        self.num_generated_tokens = 0
                        self.last_stats_tic = time.time()
                        logger.info(
                            f"[gpu_id={self.gpu_id}] Decode batch. "
                            f"#running-req: {len(self.running_batch.reqs)}, "
                            f"#token: {num_used}, "
                            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                            f"gen throughput (token/s): {throughput:.2f}, "
                            f"#queue-req: {len(self.forward_queue)}"
                        )

                    if self.running_batch.is_empty():
                        self.running_batch = None
                        break

                    # Return early so streaming clients see new tokens promptly.
                    if self.out_pyobjs and self.running_batch.reqs[0].stream:
                        break
            else:
                # Check the available size
                available_size = (
                    self.token_to_kv_pool.available_size()
                    + self.tree_cache.evictable_size()
                )
                if available_size != self.max_total_num_tokens:
                    warnings.warn(
                        "Warning: "
                        f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
                        "KV cache pool leak detected!"
                    )

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Turn a tokenized generate request into a ``Req`` and enqueue it.

        Image requests have their input ids padded by the model's
        ``pad_input_ids`` using pad values derived from the image hash.
        Regex-constrained requests get an FSM and, unless disabled, a
        jump-forward map.  The prompt is truncated to ``context_len - 1``
        tokens and ``max_new_tokens`` is clamped to what fits in the context
        window and the KV pool.
        """
        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
        req.pixel_values = recv_req.pixel_values
        if req.pixel_values is not None:
            # Placeholder ids taken from different bit ranges of the image
            # hash, handed to the model's pad_input_ids below.
            req.pad_value = [
                (recv_req.image_hash) % self.model_config.vocab_size,
                (recv_req.image_hash >> 16) % self.model_config.vocab_size,
                (recv_req.image_hash >> 32) % self.model_config.vocab_size,
                (recv_req.image_hash >> 64) % self.model_config.vocab_size,
            ]
            req.image_size = recv_req.image_size
            req.origin_input_ids, req.image_offset = (
                self.model_runner.model.pad_input_ids(
                    req.origin_input_ids_unpadded,
                    req.pad_value,
                    req.pixel_values.shape,
                    req.image_size,
                )
            )
        req.sampling_params = recv_req.sampling_params
        req.return_logprob = recv_req.return_logprob
        req.logprob_start_len = recv_req.logprob_start_len
        req.top_logprobs_num = recv_req.top_logprobs_num
        req.stream = recv_req.stream
        req.tokenizer = self.tokenizer

        # Init regex fsm
        if req.sampling_params.regex is not None:
            req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
            if not self.disable_regex_jump_forward:
                req.jump_forward_map = self.jump_forward_cache.query(
                    req.sampling_params.regex
                )

        # Truncate prompts that are too long
        req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
        # Clamp the generation budget to the context window and to the KV pool
        # (minus a 128-token margin).  The floor of 0 prevents a negative
        # budget when the prompt alone exceeds the KV pool.
        req.sampling_params.max_new_tokens = max(
            0,
            min(
                req.sampling_params.max_new_tokens,
                self.model_config.context_len - 1 - len(req.origin_input_ids),
                self.max_total_num_tokens - 128 - len(req.origin_input_ids),
            ),
        )
        self.forward_queue.append(req)

    def get_new_fill_batch(self):
        """Admit queued requests into a new prefill batch.

        The longest cached prefix of every queued request is looked up in the
        radix cache, so only ``extend_input_len`` new tokens need computing.
        Requests are then taken in scheduler priority order while both the KV
        budget and ``max_prefill_tokens`` allow.  The first request that does
        not fit stops admission.

        Returns:
            A new ``Batch`` of admitted requests, which are removed from
            ``forward_queue``.  Returns ``None`` if the running batch is
            already at ``max_running_requests`` or no request fits.
        """
        if (
            self.running_batch is not None
            and len(self.running_batch.reqs) >= self.max_running_requests
        ):
            # The running batch is already at its size limit.
            return None

        # Compute matched prefix length
        for req in self.forward_queue:
            assert (
                len(req.output_ids) == 0
            ), "The output ids should be empty when prefilling"
            req.input_ids = req.origin_input_ids + req.prev_output_ids
            prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
            if req.return_logprob:
                # Tokens from logprob_start_len on must be recomputed so their
                # prompt logprobs can be produced.
                prefix_indices = prefix_indices[: req.logprob_start_len]
            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
            req.prefix_indices = prefix_indices
            req.last_node = last_node

        # Get priority queue
        self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)

        # Add requests if there is available space
        can_run_list = []
        new_batch_total_tokens = 0
        new_batch_input_tokens = 0

        # KV slots that are free or evictable, minus a reservation for the
        # tokens the running requests are still expected to generate.
        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        if self.running_batch:
            available_size -= sum(
                (r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
                for r in self.running_batch.reqs
            )

        for req in self.forward_queue:
            if req.return_logprob and req.normalized_prompt_logprob is None:
                # Need at least two tokens to compute normalized logprob
                if req.extend_input_len < 2:
                    delta = 2 - req.extend_input_len
                    req.extend_input_len += delta
                    req.prefix_indices = req.prefix_indices[:-delta]
                    if req.image_offset is not None:
                        req.image_offset += delta
            if req.extend_input_len == 0 and req.max_new_tokens() > 0:
                # Need at least one token to compute logits
                req.extend_input_len = 1
                req.prefix_indices = req.prefix_indices[:-1]
                if req.image_offset is not None:
                    req.image_offset += 1

            if (
                req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                < available_size
                and req.extend_input_len + new_batch_input_tokens
                < self.max_prefill_tokens
            ):
                # Lock the matched prefix node (presumably so it cannot be
                # evicted while in use).  The returned delta changes the
                # budget, so re-check it afterwards.
                delta = self.tree_cache.inc_lock_ref(req.last_node)
                available_size += delta

                if not (
                    req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                    < available_size
                ):
                    # Undo locking
                    delta = self.tree_cache.dec_lock_ref(req.last_node)
                    available_size += delta
                    break
                else:
                    # Add this request to the running batch
                    can_run_list.append(req)
                    new_batch_total_tokens += (
                        req.extend_input_len + req.max_new_tokens()
                    )
                    new_batch_input_tokens += req.extend_input_len
            else:
                break
        if not can_run_list:
            return None

        # Print stats
        if self.tp_rank == 0:
            running_req = (
                0 if self.running_batch is None else len(self.running_batch.reqs)
            )
            hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
            # Cumulative counters, kept in units of 1e9 tokens.
            self.tree_cache_metrics["total"] += (
                hit_tokens + new_batch_input_tokens
            ) / 10**9
            self.tree_cache_metrics["hit"] += hit_tokens / 10**9
            # Guard against 0/0 when no tokens have been counted yet.
            total = self.tree_cache_metrics["total"]
            tree_cache_hit_rate = (
                self.tree_cache_metrics["hit"] / total if total > 0 else 0.0
            )
            logger.info(
                f"[gpu_id={self.gpu_id}] Prefil batch. "
                f"#new-seq: {len(can_run_list)}, "
                f"#new-token: {new_batch_input_tokens}, "
                f"#cached-token: {hit_tokens}, "
                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                f"#running-req: {running_req}, "
                f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
            )

        # Return the new batch
        new_batch = Batch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
        )
        # Admission stops at the first request that does not fit, so the
        # admitted requests are exactly a prefix of the sorted queue.
        self.forward_queue = self.forward_queue[len(can_run_list):]
        return new_batch

    def forward_fill_batch(self, batch: Batch):
        """Run the prefill (extend) forward pass and sample each request's first token.

        It also records prompt/decode logprobs for requests that asked for
        them, then hands the batch to ``handle_finished_requests``.  When the
        batch has no tokens to extend, no forward pass runs and every request
        receives the tokenizer's EOS token.
        """
        # Build batch tensors
        batch.prepare_for_extend(
            self.model_config.vocab_size, self.int_token_logit_bias
        )

        if batch.extend_num_tokens != 0:
            # Forward
            logits, (
                prefill_token_logprobs,
                normalized_prompt_logprobs,
                prefill_top_logprobs,
                decode_top_logprobs,
                last_logprobs,
            ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
            if prefill_token_logprobs is not None:
                prefill_token_logprobs = prefill_token_logprobs.tolist()
                normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()

            next_token_ids, _ = batch.sample(logits)

            # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
            if last_logprobs is not None:
                last_token_logprobs = last_logprobs[
                    torch.arange(len(batch.reqs), device=next_token_ids.device),
                    next_token_ids,
                ].tolist()

            next_token_ids = next_token_ids.tolist()
        else:
            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)

        # Check finish condition.  ``pt`` is the offset of this request's
        # tokens within the flattened per-token prefill logprobs.
        pt = 0
        for i, req in enumerate(batch.reqs):
            req.completion_tokens_wo_jump_forward += 1
            req.output_ids = [next_token_ids[i]]
            req.check_finished()

            if req.return_logprob:
                if req.normalized_prompt_logprob is None:
                    req.normalized_prompt_logprob = normalized_prompt_logprobs[i]

                if req.prefill_token_logprobs is None:
                    # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
                    req.prefill_token_logprobs = list(
                        zip(
                            prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
                            req.input_ids[-req.extend_input_len + 1 :],
                        )
                    )
                    if req.logprob_start_len == 0:
                        req.prefill_token_logprobs = [
                            (None, req.input_ids[0])
                        ] + req.prefill_token_logprobs

                # NOTE(review): last_update_decode_tokens appears to count
                # trailing input tokens that were originally decode outputs
                # (re-prefilled, e.g. after jump-forward); their logprobs are
                # appended to the decode lists.  Confirm in infer_batch.
                if req.last_update_decode_tokens != 0:
                    req.decode_token_logprobs.extend(
                        list(
                            zip(
                                prefill_token_logprobs[
                                    pt
                                    + req.extend_input_len
                                    - req.last_update_decode_tokens : pt
                                    + req.extend_input_len
                                    - 1
                                ],
                                req.input_ids[-req.last_update_decode_tokens + 1 :],
                            )
                        )
                    )

                req.decode_token_logprobs.append(
                    (last_token_logprobs[i], next_token_ids[i])
                )

            if req.top_logprobs_num > 0:
                if req.prefill_top_logprobs is None:
                    req.prefill_top_logprobs = prefill_top_logprobs[i]
                    if req.logprob_start_len == 0:
                        req.prefill_top_logprobs = [None] + req.prefill_top_logprobs

                if req.last_update_decode_tokens != 0:
                    req.decode_top_logprobs.extend(
                        prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
                    )
                req.decode_top_logprobs.append(decode_top_logprobs[i])

            pt += req.extend_input_len

        self.handle_finished_requests(batch)

    def cache_filled_batch(self, batch: Batch):
        """Insert the freshly prefilled tokens of each request into the radix cache.

        ``del_in_memory_pool=False`` is passed because the requests keep
        running.  The returned prefix indices and last node replace the
        request's own.
        """
        req_pool_indices_cpu = batch.req_pool_indices.tolist()
        for i, req in enumerate(batch.reqs):
            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
                token_ids=tuple(req.input_ids + req.output_ids)[:-1],
                last_uncached_pos=len(req.prefix_indices),
                req_pool_idx=req_pool_indices_cpu[i],
                del_in_memory_pool=False,
                old_last_node=req.last_node,
            )
            req.prefix_indices, req.last_node = new_prefix_indices, new_last_node

    def forward_decode_batch(self, batch: Batch):
        """Run one decode step on ``batch`` and append the sampled tokens.

        If decode memory is insufficient, some requests are retracted back to
        the queue and ``new_token_ratio`` is raised.  Otherwise the ratio
        decays toward ``min_new_token_ratio``.  Requests that can
        jump-forward are moved back to the queue as well.
        """
        # check if decode out of memory
        if not batch.check_decode_mem():
            old_ratio = self.new_token_ratio
            self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)

            retracted_reqs = batch.retract_decode()
            logger.info(
                "decode out of memory happened, "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.forward_queue.extend(retracted_reqs)
        else:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        if not self.disable_regex_jump_forward:
            # check for jump-forward
            jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)

            self.forward_queue.extend(jump_forward_reqs)
            if batch.is_empty():
                return

        # Update batch tensors
        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
        batch.prepare_for_decode()

        # Forward
        logits, (
            _,
            _,
            _,
            decode_top_logprobs,
            last_logprobs,
        ) = self.model_runner.forward(batch, ForwardMode.DECODE)
        next_token_ids, _ = batch.sample(logits)

        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
        # Index on the sampling device (as forward_fill_batch does) so the
        # index tensors need no host-to-device copy.
        if last_logprobs is not None:
            new_token_logprobs = last_logprobs[
                torch.arange(len(batch.reqs), device=next_token_ids.device),
                next_token_ids,
            ].tolist()

        next_token_ids = next_token_ids.tolist()

        # Check finish condition
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            req.completion_tokens_wo_jump_forward += 1
            req.output_ids.append(next_token_id)
            req.check_finished()

            if req.return_logprob:
                req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))

            if req.top_logprobs_num > 0:
                req.decode_top_logprobs.append(decode_top_logprobs[i])

        self.handle_finished_requests(batch)

    def handle_finished_requests(self, batch: Batch):
        """Emit outputs for finished or stream-due requests and drop finished ones.

        A request is reported when it has finished, or when it streams and
        either this is a stream-interval step or it has exactly one output
        token.  Finished requests are written into the radix cache, their
        prefix lock is released, and they are filtered out of ``batch``.
        """
        output_rids = []
        prev_output_strs = []
        output_tokens = []
        output_skip_special_tokens = []
        output_spaces_between_special_tokens = []
        output_meta_info = []
        output_finished_reason: List[BaseFinishReason] = []
        finished_indices = []
        unfinished_indices = []
        for i, req in enumerate(batch.reqs):
            if req.finished():
                finished_indices.append(i)
            else:
                unfinished_indices.append(i)

            if req.finished() or (
                (
                    req.stream
                    and (
                        self.decode_forward_ct % self.stream_interval == 0
                        or len(req.output_ids) == 1
                    )
                )
            ):
                output_rids.append(req.rid)
                prev_output_strs.append(req.prev_output_str)
                output_tokens.append(req.output_ids)
                output_skip_special_tokens.append(
                    req.sampling_params.skip_special_tokens
                )
                output_spaces_between_special_tokens.append(
                    req.sampling_params.spaces_between_special_tokens
                )

                meta_info = {
                    "prompt_tokens": len(req.origin_input_ids),
                    "completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
                    "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                    "finish_reason": str(req.finished_reason),
                }
                if req.return_logprob:
                    (
                        meta_info["prefill_token_logprobs"],
                        meta_info["decode_token_logprobs"],
                        meta_info["prefill_top_logprobs"],
                        meta_info["decode_top_logprobs"],
                        meta_info["normalized_prompt_logprob"],
                    ) = (
                        req.prefill_token_logprobs,
                        req.decode_token_logprobs,
                        req.prefill_top_logprobs,
                        req.decode_top_logprobs,
                        req.normalized_prompt_logprob,
                    )
                output_meta_info.append(meta_info)
                output_finished_reason.append(req.finished_reason)

        # Send to detokenizer
        if output_rids:
            self.out_pyobjs.append(
                BatchTokenIDOut(
                    output_rids,
                    prev_output_strs,
                    output_tokens,
                    output_skip_special_tokens,
                    output_spaces_between_special_tokens,
                    output_meta_info,
                    output_finished_reason,
                )
            )

        # Remove finished reqs
        if finished_indices:
            # Update radix cache
            req_pool_indices_cpu = batch.req_pool_indices.tolist()
            for i in finished_indices:
                req = batch.reqs[i]
                self.tree_cache.cache_req(
                    token_ids=tuple(req.input_ids + req.output_ids)[:-1],
                    last_uncached_pos=len(req.prefix_indices),
                    req_pool_idx=req_pool_indices_cpu[i],
                )

                self.tree_cache.dec_lock_ref(req.last_node)

            # Update batch tensors
            if unfinished_indices:
                batch.filter_batch(unfinished_indices)
            else:
                batch.reqs = []

    def flush_cache(self):
        """Reset every cache and memory pool, but only when no request is pending.

        If requests are queued or running, nothing is flushed and a warning
        is issued instead.
        """
        if len(self.forward_queue) == 0 and (
            self.running_batch is None or len(self.running_batch.reqs) == 0
        ):
            self.tree_cache.reset()
            self.tree_cache_metrics = {"total": 0, "hit": 0}
            self.regex_fsm_cache.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool.clear()
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
        else:
            warnings.warn(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.forward_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )

    def abort_request(self, recv_req):
        """Abort the request whose ``rid`` matches ``recv_req.rid``.

        A waiting request is removed from the queue.  A running request is
        marked ``FINISH_ABORT`` so the next ``handle_finished_requests`` call
        reports and removes it.
        """
        # Delete requests in the waiting queue
        for i, req in enumerate(self.forward_queue):
            if req.rid == recv_req.rid:
                del self.forward_queue[i]
                break

        # Delete requests in the running batch
        if self.running_batch:
            for req in self.running_batch.reqs:
                if req.rid == recv_req.rid:
                    req.finished_reason = FINISH_ABORT()
                    break
|
722
|
+
|
723
|
+
|
724
|
+
class ModelTpService(rpyc.Service):
    """rpyc service through which a remote client constructs a ``ModelTpServer``.

    rpyc only exposes attributes whose names carry the ``exposed_`` prefix.
    Clients in this file reach this one as ``<service>.ModelTpServer(...)``,
    i.e. without the prefix.
    """

    # Exposing the class itself lets the client instantiate it remotely.
    exposed_ModelTpServer = ModelTpServer
|
726
|
+
|
727
|
+
|
728
|
+
class ModelTpClient:
    """Create the ``ModelTpServer`` worker(s) and expose an async ``step``.

    With a single worker overall (``tp_size * dp_size == 1``), the server is
    built in this process and ``step`` calls it directly.  Otherwise one rpyc
    process is started per entry of ``model_port_args.model_tp_ports``.  A
    ``ModelTpServer`` is built in each (rank ``i`` on ``gpu_ids[i]``), and
    ``step`` calls all of them concurrently and returns rank 0's result.
    """

    def __init__(
        self,
        gpu_ids: List[int],
        server_args: ServerArgs,
        model_port_args: ModelPortArgs,
        model_overide_args,
    ):
        """Start the worker(s).

        Args:
            gpu_ids: GPU index for each tensor-parallel rank.
            server_args: Server configuration (possibly an rpyc proxy).
            model_port_args: Port configuration; ``model_tp_ports`` gives one
                port per worker process in the multi-worker case.
            model_overide_args: Overrides forwarded to every ``ModelTpServer``.
        """
        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
        self.tp_size = server_args.tp_size

        if self.tp_size * server_args.dp_size == 1:
            # Init model
            assert len(gpu_ids) == 1
            # ModelTpServer takes (gpu_id, tp_rank, ...): run on the given GPU
            # as the only rank, 0.
            self.model_server = ModelTpService().exposed_ModelTpServer(
                gpu_ids[0],
                0,
                server_args,
                model_port_args,
                model_overide_args,
            )

            # Wrap functions
            def async_wrap(f):
                # Offer the synchronous call through the same coroutine
                # interface as the multi-process path.
                async def _func(*args, **kwargs):
                    return f(*args, **kwargs)

                return _func

            self.step = async_wrap(self.model_server.exposed_step)
        else:
            with ThreadPoolExecutor(self.tp_size) as executor:
                # Launch model processes.  ``Executor.map`` returns a one-shot
                # iterator; materialize it because it is read twice below.
                rets = list(
                    executor.map(
                        lambda args: start_rpyc_process(*args),
                        [(ModelTpService, p) for p in model_port_args.model_tp_ports],
                    )
                )
                self.model_services = [x[0] for x in rets]
                self.procs = [x[1] for x in rets]

                # Init model
                def init_model(i):
                    return self.model_services[i].ModelTpServer(
                        gpu_ids[i],
                        i,
                        server_args,
                        model_port_args,
                        model_overide_args,
                    )

                # Keep a reusable list rather than a one-shot iterator.
                self.model_servers = list(
                    executor.map(init_model, range(self.tp_size))
                )

            # Wrap functions
            def async_wrap(func_name):
                fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]

                async def _func(*args, **kwargs):
                    # Issue the call on every rank, wait for all of them in
                    # worker threads, and return rank 0's result.
                    tasks = [f(*args, **kwargs) for f in fs]
                    await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
                    return obtain(tasks[0].value)

                return _func

            self.step = async_wrap("step")
|