sglang-0.1.16-py3-none-any.whl → sglang-0.1.17-py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/__init__.py +3 -1
- sglang/api.py +3 -3
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +148 -12
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/global_config.py +8 -1
- sglang/lang/interpreter.py +114 -67
- sglang/lang/ir.py +17 -2
- sglang/srt/constrained/fsm_cache.py +3 -0
- sglang/srt/flush_cache.py +1 -1
- sglang/srt/hf_transformers_utils.py +75 -1
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +485 -0
- sglang/srt/layers/logits_processor.py +12 -7
- sglang/srt/layers/radix_attention.py +10 -3
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +110 -0
- sglang/srt/managers/controller/infer_batch.py +619 -0
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/controller/manager_single.py +97 -0
- sglang/srt/managers/controller/model_runner.py +462 -0
- sglang/srt/managers/controller/radix_cache.py +267 -0
- sglang/srt/managers/controller/schedule_heuristic.py +59 -0
- sglang/srt/managers/controller/tp_worker.py +791 -0
- sglang/srt/managers/detokenizer_manager.py +45 -45
- sglang/srt/managers/io_struct.py +15 -11
- sglang/srt/managers/router/infer_batch.py +103 -59
- sglang/srt/managers/router/manager.py +1 -1
- sglang/srt/managers/router/model_rpc.py +175 -122
- sglang/srt/managers/router/model_runner.py +91 -104
- sglang/srt/managers/router/radix_cache.py +7 -1
- sglang/srt/managers/router/scheduler.py +6 -6
- sglang/srt/managers/tokenizer_manager.py +152 -89
- sglang/srt/model_config.py +4 -5
- sglang/srt/models/commandr.py +10 -13
- sglang/srt/models/dbrx.py +9 -15
- sglang/srt/models/gemma.py +8 -15
- sglang/srt/models/grok.py +671 -0
- sglang/srt/models/llama2.py +19 -15
- sglang/srt/models/llava.py +84 -20
- sglang/srt/models/llavavid.py +11 -20
- sglang/srt/models/mixtral.py +248 -118
- sglang/srt/models/mixtral_quant.py +373 -0
- sglang/srt/models/qwen.py +9 -13
- sglang/srt/models/qwen2.py +11 -13
- sglang/srt/models/stablelm.py +9 -15
- sglang/srt/models/yivl.py +17 -22
- sglang/srt/openai_api_adapter.py +140 -95
- sglang/srt/openai_protocol.py +10 -1
- sglang/srt/server.py +77 -42
- sglang/srt/server_args.py +51 -6
- sglang/srt/utils.py +124 -66
- sglang/test/test_programs.py +44 -0
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +22 -4
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/METADATA +15 -9
- sglang-0.1.17.dist-info/RECORD +81 -0
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -417
- sglang-0.1.16.dist-info/RECORD +0 -72
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/managers/router/model_rpc.py

@@ -4,7 +4,7 @@ import multiprocessing
 import time
 import warnings
 from concurrent.futures import ThreadPoolExecutor
-from typing import
+from typing import List, Optional
 
 import rpyc
 import torch
@@ -16,31 +16,33 @@ try:
 except ImportError:
     from vllm.logger import logger as vllm_default_logger
 
+from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import (
+    AbortReq,
     BatchTokenIDOut,
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
+from sglang.srt.managers.router.infer_batch import Batch, FinishReason, ForwardMode, Req
 from sglang.srt.managers.router.model_runner import ModelRunner
 from sglang.srt.managers.router.radix_cache import RadixCache
 from sglang.srt.managers.router.scheduler import Scheduler
 from sglang.srt.model_config import ModelConfig
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
-    get_exception_traceback,
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
 )
-
+from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger("model_rpc")
 vllm_default_logger.setLevel(logging.WARN)
 logging.getLogger("vllm.utils").setLevel(logging.WARN)
+logging.getLogger("vllm.selector").setLevel(logging.WARN)
 
 
 class ModelRpcServer:
@@ -68,20 +70,13 @@ class ModelRpcServer:
         )
 
         # For model end global settings
-        server_args_dict = {
-            "enable_flashinfer": server_args.enable_flashinfer,
-            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-        }
-
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
             nccl_port=port_args.nccl_port,
-
-            trust_remote_code=server_args.trust_remote_code,
-            server_args_dict=server_args_dict,
+            server_args=server_args,
         )
         if is_multimodal_model(server_args.model_path):
             self.processor = get_processor(
@@ -96,24 +91,27 @@ class ModelRpcServer:
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
             )
-        self.
-        self.
-        self.max_prefill_num_token = max(
+        self.max_total_num_tokens = self.model_runner.max_total_num_tokens
+        self.max_prefill_tokens = max(
             self.model_config.context_len,
             (
-                self.
-                if server_args.
-                else server_args.
+                self.max_total_num_tokens // 6
+                if server_args.max_prefill_tokens is None
+                else server_args.max_prefill_tokens
             ),
         )
+        self.max_running_requests = (self.max_total_num_tokens // 2
+            if server_args.max_running_requests is None else server_args.max_running_requests)
+
         self.int_token_logit_bias = torch.tensor(
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
         )
         set_random_seed(server_args.random_seed)
-
-
-
-            f"
+
+        # Print info
+        logger.info(f"[rank={self.tp_rank}] "
+            f"max_total_num_tokens={self.max_total_num_tokens}, "
+            f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"context_len={self.model_config.context_len}, "
         )
         if self.tp_rank == 0:
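Note: the hunk above replaces the old max_prefill_num_token bookkeeping with max_total_num_tokens, max_prefill_tokens, and a new max_running_requests cap. A minimal sketch of how the defaults appear to be derived in this version (the helper below is illustrative only, not part of sglang, and the example numbers are assumed):

def derive_token_budgets(max_total_num_tokens, context_len,
                         max_prefill_tokens=None, max_running_requests=None):
    # Prefill budget: the explicit server arg if given, otherwise roughly 1/6 of
    # the KV-cache token capacity, but never below the model's context length.
    prefill = max(
        context_len,
        max_total_num_tokens // 6 if max_prefill_tokens is None else max_prefill_tokens,
    )
    # Concurrency cap: the explicit server arg if given, otherwise half the capacity.
    running = (max_total_num_tokens // 2
               if max_running_requests is None else max_running_requests)
    return prefill, running

# Assumed numbers: a 120k-token KV cache and a 4k context window.
print(derive_token_budgets(120_000, 4096))  # -> (20000, 60000)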
@@ -128,9 +126,9 @@ class ModelRpcServer:
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.scheduler = Scheduler(
             self.schedule_heuristic,
-            self.
-            self.
-            self.
+            self.max_running_requests,
+            self.max_prefill_tokens,
+            self.max_total_num_tokens,
             self.tree_cache,
         )
         self.req_to_token_pool = self.model_runner.req_to_token_pool
@@ -156,27 +154,20 @@ class ModelRpcServer:
         self.jump_forward_cache = JumpForwardCache()
 
         # Init new token estimation
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            logger.info("Cache flushed successfully!")
-        else:
-            warnings.warn(
-                f"Cache not flushed because there are pending requests. "
-                f"#queue-req: {len(self.forward_queue)}, "
-                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
-            )
+        assert (
+            server_args.schedule_conservativeness >= 0
+        ), "Invalid schedule_conservativeness"
+        self.new_token_ratio = min(
+            global_config.base_new_token_ratio * server_args.schedule_conservativeness,
+            1.0,
+        )
+        self.min_new_token_ratio = min(
+            global_config.base_min_new_token_ratio
+            * server_args.schedule_conservativeness,
+            1.0,
+        )
+        self.new_token_ratio_decay = global_config.new_token_ratio_decay
+        self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
 
     def exposed_step(self, recv_reqs):
         if self.tp_size != 1:
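Note: the new-token-estimation ratios above now come from global_config and are scaled by the --schedule-conservativeness server argument (the new assert validates it). The base constants live in sglang/global_config.py, which is not shown in this hunk, so the values below are placeholders; the sketch only illustrates the scaling and the decay/recovery behavior used later in forward_decode_batch:

# Placeholder base values; the real ones are defined in sglang/global_config.py.
BASE_NEW_TOKEN_RATIO = 0.4
BASE_MIN_NEW_TOKEN_RATIO = 0.2
NEW_TOKEN_RATIO_DECAY = 0.0001
NEW_TOKEN_RATIO_RECOVERY = 0.05

def init_ratios(schedule_conservativeness):
    assert schedule_conservativeness >= 0, "Invalid schedule_conservativeness"
    # A higher conservativeness reserves more KV-cache room per running request,
    # so fewer new requests are admitted at once.
    new_token_ratio = min(BASE_NEW_TOKEN_RATIO * schedule_conservativeness, 1.0)
    min_new_token_ratio = min(BASE_MIN_NEW_TOKEN_RATIO * schedule_conservativeness, 1.0)
    return new_token_ratio, min_new_token_ratio

ratio, min_ratio = init_ratios(1.0)
# In the decode hunks further down, the ratio is bumped (capped at 1.0) whenever
# requests are retracted for lack of KV cache, and otherwise decays toward min_ratio.
ratio = min(ratio + NEW_TOKEN_RATIO_RECOVERY, 1.0)     # after a retraction
ratio = max(ratio - NEW_TOKEN_RATIO_DECAY, min_ratio)  # on a normal decode step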
@@ -189,6 +180,8 @@ class ModelRpcServer:
                 self.handle_generate_request(recv_req)
             elif isinstance(recv_req, FlushCacheReq):
                 self.flush_cache()
+            elif isinstance(recv_req, AbortReq):
+                self.abort_request(recv_req)
             else:
                 raise ValueError(f"Invalid request: {recv_req}")
 
@@ -207,9 +200,8 @@ class ModelRpcServer:
         new_batch = self.get_new_fill_batch()
 
         if new_batch is not None:
-            # Run new fill batch
+            # Run a new fill batch
             self.forward_fill_batch(new_batch)
-
             self.cache_filled_batch(new_batch)
 
             if not new_batch.is_empty():
@@ -225,39 +217,42 @@ class ModelRpcServer:
                     self.num_generated_tokens += len(self.running_batch.reqs)
                     self.forward_decode_batch(self.running_batch)
 
-
-
-                        break
-
-                    if self.out_pyobjs and self.running_batch.reqs[0].stream:
-                        break
-
-                    if self.running_batch is not None and self.tp_rank == 0:
+                    # Print stats
+                    if self.tp_rank == 0:
                         if self.decode_forward_ct % 40 == 0:
-                            num_used = self.
+                            num_used = self.max_total_num_tokens - (
                                 self.token_to_kv_pool.available_size()
                                 + self.tree_cache.evictable_size()
                             )
-                            throuhgput = self.num_generated_tokens / (
+                            throuhgput = self.num_generated_tokens / (
+                                time.time() - self.last_stats_tic
+                            )
                             self.num_generated_tokens = 0
                             self.last_stats_tic = time.time()
                             logger.info(
                                 f"#running-req: {len(self.running_batch.reqs)}, "
                                 f"#token: {num_used}, "
-                                f"token usage: {num_used / self.
+                                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                                 f"gen throughput (token/s): {throuhgput:.2f}, "
                                 f"#queue-req: {len(self.forward_queue)}"
                             )
+
+                    if self.running_batch.is_empty():
+                        self.running_batch = None
+                        break
+
+                    if self.out_pyobjs and self.running_batch.reqs[0].stream:
+                        break
             else:
-                #
+                # Check the available size
                 available_size = (
                     self.token_to_kv_pool.available_size()
                     + self.tree_cache.evictable_size()
                 )
-                if available_size != self.
+                if available_size != self.max_total_num_tokens:
                     warnings.warn(
                         "Warning: "
-                        f"available_size={available_size},
+                        f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
                        "KV cache pool leak detected!"
                     )
 
@@ -275,8 +270,13 @@ class ModelRpcServer:
                 (recv_req.image_hash >> 64) % self.model_config.vocab_size,
             ]
             req.image_size = recv_req.image_size
-            req.
-
+            req.origin_input_ids, req.image_offset = (
+                self.model_runner.model.pad_input_ids(
+                    req.origin_input_ids_unpadded,
+                    req.pad_value,
+                    req.pixel_values.shape,
+                    req.image_size,
+                )
             )
         req.sampling_params = recv_req.sampling_params
         req.return_logprob = recv_req.return_logprob
@@ -293,23 +293,28 @@ class ModelRpcServer:
                 req.sampling_params.regex
             )
 
-        # Truncate long
-        req.
+        # Truncate prompts that are too long
+        req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
         req.sampling_params.max_new_tokens = min(
             req.sampling_params.max_new_tokens,
-            self.model_config.context_len - 1 - len(req.
-            self.
+            self.model_config.context_len - 1 - len(req.origin_input_ids),
+            self.max_total_num_tokens - 128 - len(req.origin_input_ids),
         )
         self.forward_queue.append(req)
 
     def get_new_fill_batch(self):
         if (
             self.running_batch is not None
-            and len(self.running_batch.reqs) > self.
+            and len(self.running_batch.reqs) > self.max_running_requests
         ):
             return None
 
+        # Compute matched prefix length
         for req in self.forward_queue:
+            assert (
+                len(req.output_ids) == 0
+            ), "The output ids should be empty when prefilling"
+            req.input_ids = req.origin_input_ids + req.prev_output_ids
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
             if req.return_logprob:
                 prefix_indices = prefix_indices[: req.logprob_start_len]
@@ -337,7 +342,7 @@ class ModelRpcServer:
         )
 
         for req in self.forward_queue:
-            if req.return_logprob:
+            if req.return_logprob and req.normalized_prompt_logprob is None:
                 # Need at least two tokens to compute normalized logprob
                 if req.extend_input_len < 2:
                     delta = 2 - req.extend_input_len
@@ -356,7 +361,7 @@ class ModelRpcServer:
                 req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                 < available_size
                 and req.extend_input_len + new_batch_input_tokens
-                < self.
+                < self.max_prefill_tokens
             ):
                 delta = self.tree_cache.inc_lock_ref(req.last_node)
                 available_size += delta
@@ -381,6 +386,7 @@ class ModelRpcServer:
         if len(can_run_list) == 0:
             return None
 
+        # Print stats
         if self.tp_rank == 0:
             running_req = (
                 0 if self.running_batch is None else len(self.running_batch.reqs)
@@ -401,13 +407,14 @@ class ModelRpcServer:
                 f"#running_req: {running_req}. "
                 f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
             )
-            #logger.debug(
+            # logger.debug(
             # f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
             # f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
             # f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
             # f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
-            #)
+            # )
 
+        # Return the new batch
         new_batch = Batch.init_new(
             can_run_list,
             self.req_to_token_pool,
@@ -440,11 +447,10 @@ class ModelRpcServer:
 
         # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
         if last_logprobs is not None:
-            last_token_logprobs =
-
-
-
-            )
+            last_token_logprobs = last_logprobs[
+                torch.arange(len(batch.reqs), device=next_token_ids.device),
+                next_token_ids,
+            ].tolist()
 
             next_token_ids = next_token_ids.tolist()
         else:
@@ -458,35 +464,60 @@ class ModelRpcServer:
             req.check_finished()
 
             if req.return_logprob:
-                req.normalized_prompt_logprob
-
-
-                req.prefill_token_logprobs
-
-
-
+                if req.normalized_prompt_logprob is None:
+                    req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
+
+                if req.prefill_token_logprobs is None:
+                    # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+                    req.prefill_token_logprobs = list(
+                        zip(
+                            prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                            req.input_ids[-req.extend_input_len + 1 :],
+                        )
                     )
-
-
-
-
-
-                req.
+                    if req.logprob_start_len == 0:
+                        req.prefill_token_logprobs = [
+                            (None, req.input_ids[0])
+                        ] + req.prefill_token_logprobs
+
+                if req.last_update_decode_tokens != 0:
+                    req.decode_token_logprobs.extend(
+                        list(
+                            zip(
+                                prefill_token_logprobs[
+                                    pt
+                                    + req.extend_input_len
+                                    - req.last_update_decode_tokens : pt
+                                    + req.extend_input_len
+                                    - 1
+                                ],
+                                req.input_ids[-req.last_update_decode_tokens + 1 :],
+                            )
+                        )
+                    )
+
+                req.decode_token_logprobs.append(
                     (last_token_logprobs[i], next_token_ids[i])
-
+                )
 
             if req.top_logprobs_num > 0:
-                req.prefill_top_logprobs
-
-                req.
-
+                if req.prefill_top_logprobs is None:
+                    req.prefill_top_logprobs = prefill_top_logprobs[i]
+                    if req.logprob_start_len == 0:
+                        req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+
+                if req.last_update_decode_tokens != 0:
+                    req.decode_top_logprobs.extend(
+                        prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+                    )
+                req.decode_top_logprobs.append(decode_top_logprobs[i])
 
             pt += req.extend_input_len
 
         self.handle_finished_requests(batch)
 
     def cache_filled_batch(self, batch: Batch):
-        req_pool_indices_cpu = batch.req_pool_indices.cpu().
+        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
         for i, req in enumerate(batch.reqs):
             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
                 token_ids=tuple(req.input_ids + req.output_ids)[:-1],
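Note: the prefill-logprob bookkeeping above aligns each returned logprob with the token it scores. pt is this request's offset into the batch-flattened prefill_token_logprobs, the first token of a prompt has no logprob, and with logprob_start_len == 0 a (None, first_token) placeholder is prepended. A toy illustration with made-up values (not sglang code):

# One hypothetical request with a 5-token prompt inside a batch.
input_ids = [101, 7, 8, 9, 10]
extend_input_len = len(input_ids)
pt = 0  # offset of this request in the flattened per-batch logprob list
prefill_token_logprobs = [-1.2, -0.3, -2.0, -0.7]  # extend_input_len - 1 values

# Same zip as in the diff: logprob k scores token k + 1 of the extended input.
pairs = list(zip(
    prefill_token_logprobs[pt : pt + extend_input_len - 1],
    input_ids[-extend_input_len + 1 :],
))
# logprob_start_len == 0: prepend the first token, which has no logprob.
pairs = [(None, input_ids[0])] + pairs
print(pairs)
# [(None, 101), (-1.2, 7), (-0.3, 8), (-2.0, 9), (-0.7, 10)]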
@@ -501,7 +532,7 @@ class ModelRpcServer:
         # check if decode out of memory
         if not batch.check_decode_mem():
             old_ratio = self.new_token_ratio
-            self.new_token_ratio = min(old_ratio + self.
+            self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
 
             retracted_reqs = batch.retract_decode()
             logger.info(
@@ -512,26 +543,13 @@ class ModelRpcServer:
             self.forward_queue.extend(retracted_reqs)
         else:
             self.new_token_ratio = max(
-                self.new_token_ratio - self.
+                self.new_token_ratio - self.new_token_ratio_decay,
                 self.min_new_token_ratio,
             )
 
         if not self.disable_regex_jump_forward:
             # check for jump-forward
-            jump_forward_reqs = batch.check_for_jump_forward()
-
-            # check for image jump-forward
-            for req in jump_forward_reqs:
-                if req.pixel_values is not None:
-                    (
-                        req.input_ids,
-                        req.image_offset,
-                    ) = self.model_runner.model.pad_input_ids(
-                        req.input_ids,
-                        req.pad_value,
-                        req.pixel_values.shape,
-                        req.image_size,
-                    )
+            jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
 
             self.forward_queue.extend(jump_forward_reqs)
             if batch.is_empty():
@@ -574,8 +592,8 @@ class ModelRpcServer:
 
     def handle_finished_requests(self, batch: Batch):
         output_rids = []
+        prev_output_strs = []
         output_tokens = []
-        output_and_jump_forward_strs = []
         output_hit_stop_str = []
         output_skip_special_tokens = []
         output_spaces_between_special_tokens = []
@@ -599,8 +617,8 @@ class ModelRpcServer:
                 )
             ):
                 output_rids.append(req.rid)
+                prev_output_strs.append(req.prev_output_str)
                 output_tokens.append(req.output_ids)
-                output_and_jump_forward_strs.append(req.output_and_jump_forward_str)
                 output_hit_stop_str.append(req.hit_stop_str)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
@@ -610,10 +628,8 @@ class ModelRpcServer:
                 )
 
                 meta_info = {
-                    "prompt_tokens": req.
-                    "completion_tokens": len(req.
-                    + len(req.output_ids)
-                    - req.prompt_tokens,
+                    "prompt_tokens": len(req.origin_input_ids),
+                    "completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                     "finish_reason": FinishReason.to_str(req.finish_reason),
                     "hit_stop_str": req.hit_stop_str,
@@ -640,8 +656,8 @@ class ModelRpcServer:
             self.out_pyobjs.append(
                 BatchTokenIDOut(
                     output_rids,
+                    prev_output_strs,
                     output_tokens,
-                    output_and_jump_forward_strs,
                     output_hit_stop_str,
                     output_skip_special_tokens,
                     output_spaces_between_special_tokens,
@@ -670,6 +686,43 @@ class ModelRpcServer:
         else:
             batch.reqs = []
 
+    def flush_cache(self):
+        if len(self.forward_queue) == 0 and (
+            self.running_batch is None or len(self.running_batch.reqs) == 0
+        ):
+            self.tree_cache.reset()
+            self.tree_cache_metrics = {"total": 0, "hit": 0}
+            self.regex_fsm_cache.reset()
+            self.req_to_token_pool.clear()
+            self.token_to_kv_pool.clear()
+            torch.cuda.empty_cache()
+            logger.info("Cache flushed successfully!")
+        else:
+            warnings.warn(
+                f"Cache not flushed because there are pending requests. "
+                f"#queue-req: {len(self.forward_queue)}, "
+                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+            )
+
+    def abort_request(self, recv_req):
+        # Delete requests in the waiting queue
+        to_del = None
+        for i, req in enumerate(self.forward_queue):
+            if req.rid == recv_req.rid:
+                to_del = i
+                break
+
+        if to_del is not None:
+            del self.forward_queue[to_del]
+
+        # Delete requests in the running batch
+        if self.running_batch:
+            for req in self.running_batch.reqs:
+                if req.rid == recv_req.rid:
+                    req.finished = True
+                    req.finish_reason = FinishReason.ABORT
+                    break
+
 
 class ModelRpcService(rpyc.Service):
     exposed_ModelRpcServer = ModelRpcServer
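Note: the new abort_request above is reached through the AbortReq branch added to exposed_step earlier in this diff. A rough sketch of the two cases it handles, using stand-in objects (it assumes AbortReq carries the request id rid, which this hunk does not spell out):

# Stand-ins for the real request/queue objects; illustration only.
class FakeReq:
    def __init__(self, rid):
        self.rid = rid
        self.finished = False
        self.finish_reason = None

forward_queue = [FakeReq("a"), FakeReq("b")]   # waiting, not yet scheduled
running_reqs = [FakeReq("c")]                  # currently decoding

def abort(rid):
    # Case 1: still waiting -> drop it from the queue before it is ever prefilled.
    for i, req in enumerate(forward_queue):
        if req.rid == rid:
            del forward_queue[i]
            return
    # Case 2: already running -> mark it finished so the decode loop and
    # handle_finished_requests retire it (FinishReason.ABORT in the real code).
    for req in running_reqs:
        if req.rid == rid:
            req.finished = True
            req.finish_reason = "ABORT"
            return

abort("a")
abort("c")
print([r.rid for r in forward_queue], running_reqs[0].finished)  # ['b'] True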
@@ -731,7 +784,7 @@ def _init_service(port):
         protocol_config={
             "allow_public_attrs": True,
             "allow_pickle": True,
-            "sync_request_timeout":
+            "sync_request_timeout": 3600,
         },
     )
     t.start()
@@ -751,7 +804,7 @@ def start_model_process(port):
                 config={
                     "allow_public_attrs": True,
                     "allow_pickle": True,
-                    "sync_request_timeout":
+                    "sync_request_timeout": 3600,
                 },
            )
            break