sglang 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
- sglang/__init__.py +5 -1
- sglang/api.py +8 -3
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +148 -12
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/global_config.py +11 -1
- sglang/lang/chat_template.py +9 -2
- sglang/lang/interpreter.py +161 -81
- sglang/lang/ir.py +29 -11
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -2
- sglang/launch_server_llavavid.py +31 -0
- sglang/srt/constrained/fsm_cache.py +3 -0
- sglang/srt/flush_cache.py +16 -0
- sglang/srt/hf_transformers_utils.py +83 -2
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +485 -0
- sglang/srt/layers/logits_processor.py +12 -7
- sglang/srt/layers/radix_attention.py +10 -3
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +110 -0
- sglang/srt/managers/controller/infer_batch.py +619 -0
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/controller/manager_single.py +97 -0
- sglang/srt/managers/controller/model_runner.py +462 -0
- sglang/srt/managers/controller/radix_cache.py +267 -0
- sglang/srt/managers/controller/schedule_heuristic.py +59 -0
- sglang/srt/managers/controller/tp_worker.py +791 -0
- sglang/srt/managers/detokenizer_manager.py +45 -45
- sglang/srt/managers/io_struct.py +26 -10
- sglang/srt/managers/router/infer_batch.py +130 -74
- sglang/srt/managers/router/manager.py +7 -9
- sglang/srt/managers/router/model_rpc.py +224 -135
- sglang/srt/managers/router/model_runner.py +94 -107
- sglang/srt/managers/router/radix_cache.py +54 -18
- sglang/srt/managers/router/scheduler.py +23 -34
- sglang/srt/managers/tokenizer_manager.py +183 -88
- sglang/srt/model_config.py +5 -2
- sglang/srt/models/commandr.py +15 -22
- sglang/srt/models/dbrx.py +22 -29
- sglang/srt/models/gemma.py +14 -24
- sglang/srt/models/grok.py +671 -0
- sglang/srt/models/llama2.py +24 -23
- sglang/srt/models/llava.py +85 -25
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/mixtral.py +254 -130
- sglang/srt/models/mixtral_quant.py +373 -0
- sglang/srt/models/qwen.py +28 -25
- sglang/srt/models/qwen2.py +17 -22
- sglang/srt/models/stablelm.py +21 -26
- sglang/srt/models/yivl.py +17 -25
- sglang/srt/openai_api_adapter.py +140 -95
- sglang/srt/openai_protocol.py +10 -1
- sglang/srt/server.py +101 -52
- sglang/srt/server_args.py +59 -11
- sglang/srt/utils.py +242 -75
- sglang/test/test_programs.py +44 -0
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +95 -26
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/METADATA +23 -13
- sglang-0.1.17.dist-info/RECORD +81 -0
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -402
- sglang-0.1.15.dist-info/RECORD +0 -69
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
--- sglang-0.1.15/sglang/srt/managers/router/model_rpc.py
+++ sglang-0.1.17/sglang/srt/managers/router/model_rpc.py
@@ -4,42 +4,45 @@ import multiprocessing
 import time
 import warnings
 from concurrent.futures import ThreadPoolExecutor
-from typing import List
+from typing import List, Optional

 import rpyc
 import torch
 from rpyc.utils.classic import obtain
 from rpyc.utils.server import ThreadedServer
+
 try:
     from vllm.logger import _default_handler as vllm_default_logger
 except ImportError:
     from vllm.logger import logger as vllm_default_logger

+from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import (
+    AbortReq,
     BatchTokenIDOut,
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
+from sglang.srt.managers.router.infer_batch import Batch, FinishReason, ForwardMode, Req
 from sglang.srt.managers.router.model_runner import ModelRunner
 from sglang.srt.managers.router.radix_cache import RadixCache
 from sglang.srt.managers.router.scheduler import Scheduler
 from sglang.srt.model_config import ModelConfig
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
-    get_exception_traceback,
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
 )
-
+from sglang.utils import get_exception_traceback

 logger = logging.getLogger("model_rpc")
 vllm_default_logger.setLevel(logging.WARN)
 logging.getLogger("vllm.utils").setLevel(logging.WARN)
+logging.getLogger("vllm.selector").setLevel(logging.WARN)


 class ModelRpcServer:
@@ -48,6 +51,7 @@ class ModelRpcServer:
         tp_rank: int,
         server_args: ServerArgs,
         port_args: PortArgs,
+        model_overide_args: Optional[dict] = None,
     ):
         server_args, port_args = [obtain(x) for x in [server_args, port_args]]

@@ -62,23 +66,17 @@ class ModelRpcServer:
             server_args.model_path,
             server_args.trust_remote_code,
             context_length=server_args.context_length,
+            model_overide_args=model_overide_args,
         )

         # For model end global settings
-        server_args_dict = {
-            "enable_flashinfer": server_args.enable_flashinfer,
-            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-        }
-
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
             nccl_port=port_args.nccl_port,
-
-            trust_remote_code=server_args.trust_remote_code,
-            server_args_dict=server_args_dict,
+            server_args=server_args,
         )
         if is_multimodal_model(server_args.model_path):
             self.processor = get_processor(
@@ -93,37 +91,44 @@ class ModelRpcServer:
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
             )
-        self.
-        self.
-        self.max_prefill_num_token = max(
+        self.max_total_num_tokens = self.model_runner.max_total_num_tokens
+        self.max_prefill_tokens = max(
             self.model_config.context_len,
             (
-                self.
-                if server_args.
-                else server_args.
+                self.max_total_num_tokens // 6
+                if server_args.max_prefill_tokens is None
+                else server_args.max_prefill_tokens
             ),
         )
+        self.max_running_requests = (self.max_total_num_tokens // 2
+            if server_args.max_running_requests is None else server_args.max_running_requests)
+
         self.int_token_logit_bias = torch.tensor(
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
         )
         set_random_seed(server_args.random_seed)
-
-
-
-            f"
+
+        # Print info
+        logger.info(f"[rank={self.tp_rank}] "
+            f"max_total_num_tokens={self.max_total_num_tokens}, "
+            f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"context_len={self.model_config.context_len}, "
         )
         if self.tp_rank == 0:
             logger.info(f"server_args: {server_args.print_mode_args()}")

         # Init cache
-        self.tree_cache = RadixCache(
+        self.tree_cache = RadixCache(
+            req_to_token_pool=self.model_runner.req_to_token_pool,
+            token_to_kv_pool=self.model_runner.token_to_kv_pool,
+            disable=server_args.disable_radix_cache,
+        )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.scheduler = Scheduler(
             self.schedule_heuristic,
-            self.
-            self.
-            self.
+            self.max_running_requests,
+            self.max_prefill_tokens,
+            self.max_total_num_tokens,
             self.tree_cache,
         )
         self.req_to_token_pool = self.model_runner.req_to_token_pool
@@ -135,6 +140,8 @@ class ModelRpcServer:
         self.out_pyobjs = []
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
+        self.num_generated_tokens = 0
+        self.last_stats_tic = time.time()

         # Init the FSM cache for constrained generation
         self.regex_fsm_cache = FSMCache(
@@ -147,27 +154,20 @@ class ModelRpcServer:
         self.jump_forward_cache = JumpForwardCache()

         # Init new token estimation
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            logger.info("Cache flushed successfully!")
-        else:
-            warnings.warn(
-                f"Cache not flushed because there are pending requests. "
-                f"#queue-req: {len(self.forward_queue)}, "
-                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
-            )
+        assert (
+            server_args.schedule_conservativeness >= 0
+        ), "Invalid schedule_conservativeness"
+        self.new_token_ratio = min(
+            global_config.base_new_token_ratio * server_args.schedule_conservativeness,
+            1.0,
+        )
+        self.min_new_token_ratio = min(
+            global_config.base_min_new_token_ratio
+            * server_args.schedule_conservativeness,
+            1.0,
+        )
+        self.new_token_ratio_decay = global_config.new_token_ratio_decay
+        self.new_token_ratio_recovery = global_config.new_token_ratio_recovery

     def exposed_step(self, recv_reqs):
         if self.tp_size != 1:
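For orientation, the hunk above replaces hard-coded new-token-ratio constants with values read from sglang.global_config and scaled by the schedule_conservativeness server argument. The following is a minimal standalone sketch of that arithmetic only; the base_* values are assumed placeholders, not the real defaults from global_config, which this diff does not show.

```python
# Standalone sketch of the new-token-ratio setup in the hunk above.
# base_new_token_ratio / base_min_new_token_ratio are placeholder values.
base_new_token_ratio = 0.4
base_min_new_token_ratio = 0.2
schedule_conservativeness = 1.0  # ServerArgs knob validated by the assert above

assert schedule_conservativeness >= 0, "Invalid schedule_conservativeness"
new_token_ratio = min(base_new_token_ratio * schedule_conservativeness, 1.0)
min_new_token_ratio = min(base_min_new_token_ratio * schedule_conservativeness, 1.0)
print(new_token_ratio, min_new_token_ratio)  # 0.4 0.2 with these placeholders
```

Judging by how the ratio is adjusted in forward_decode_batch later in this diff, a larger conservativeness value makes the scheduler reserve more headroom for in-flight decode tokens before admitting new prefills.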
@@ -180,6 +180,8 @@ class ModelRpcServer:
                 self.handle_generate_request(recv_req)
             elif isinstance(recv_req, FlushCacheReq):
                 self.flush_cache()
+            elif isinstance(recv_req, AbortReq):
+                self.abort_request(recv_req)
             else:
                 raise ValueError(f"Invalid request: {recv_req}")

@@ -198,8 +200,9 @@ class ModelRpcServer:
         new_batch = self.get_new_fill_batch()

         if new_batch is not None:
-            # Run new fill batch
+            # Run a new fill batch
             self.forward_fill_batch(new_batch)
+            self.cache_filled_batch(new_batch)

             if not new_batch.is_empty():
                 if self.running_batch is None:
@@ -211,37 +214,45 @@ class ModelRpcServer:
         if self.running_batch is not None:
             # Run a few decode batches continuously for reducing overhead
             for _ in range(10):
+                self.num_generated_tokens += len(self.running_batch.reqs)
                 self.forward_decode_batch(self.running_batch)

-
-
-                    break
-
-                if self.out_pyobjs and self.running_batch.reqs[0].stream:
-                    break
-
-            if self.running_batch is not None and self.tp_rank == 0:
+                # Print stats
+                if self.tp_rank == 0:
                     if self.decode_forward_ct % 40 == 0:
-                        num_used = self.
+                        num_used = self.max_total_num_tokens - (
                             self.token_to_kv_pool.available_size()
                             + self.tree_cache.evictable_size()
                         )
+                        throuhgput = self.num_generated_tokens / (
+                            time.time() - self.last_stats_tic
+                        )
+                        self.num_generated_tokens = 0
+                        self.last_stats_tic = time.time()
                         logger.info(
                             f"#running-req: {len(self.running_batch.reqs)}, "
                             f"#token: {num_used}, "
-                            f"token usage: {num_used / self.
+                            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                            f"gen throughput (token/s): {throuhgput:.2f}, "
                             f"#queue-req: {len(self.forward_queue)}"
                         )
+
+                if self.running_batch.is_empty():
+                    self.running_batch = None
+                    break
+
+                if self.out_pyobjs and self.running_batch.reqs[0].stream:
+                    break
         else:
-            #
+            # Check the available size
             available_size = (
                 self.token_to_kv_pool.available_size()
                 + self.tree_cache.evictable_size()
            )
-            if available_size != self.
+            if available_size != self.max_total_num_tokens:
                warnings.warn(
                    "Warning: "
-                    f"available_size={available_size},
+                    f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
                    "KV cache pool leak detected!"
                )

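The logging added in the hunk above reports KV-cache usage (pool capacity minus free and evictable tokens) and decode throughput (tokens generated since the last report divided by elapsed wall-clock time). A self-contained sketch of that bookkeeping, with made-up numbers:

```python
import time

# Sketch of the stats computed in the hunk above; all values are illustrative.
max_total_num_tokens = 100_000   # KV pool capacity
available_size = 60_000          # token_to_kv_pool.available_size()
evictable_size = 15_000          # tree_cache.evictable_size()
num_generated_tokens = 4_096
last_stats_tic = time.time() - 2.0  # pretend the last report was 2 s ago

num_used = max_total_num_tokens - (available_size + evictable_size)
throughput = num_generated_tokens / (time.time() - last_stats_tic)
print(f"#token: {num_used}, "
      f"token usage: {num_used / max_total_num_tokens:.2f}, "
      f"gen throughput (token/s): {throughput:.2f}")
```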
@@ -259,8 +270,13 @@ class ModelRpcServer:
                 (recv_req.image_hash >> 64) % self.model_config.vocab_size,
             ]
             req.image_size = recv_req.image_size
-            req.
-
+            req.origin_input_ids, req.image_offset = (
+                self.model_runner.model.pad_input_ids(
+                    req.origin_input_ids_unpadded,
+                    req.pad_value,
+                    req.pixel_values.shape,
+                    req.image_size,
+                )
             )
         req.sampling_params = recv_req.sampling_params
         req.return_logprob = recv_req.return_logprob
@@ -277,23 +293,28 @@ class ModelRpcServer:
                 req.sampling_params.regex
             )

-        # Truncate long
-        req.
+        # Truncate prompts that are too long
+        req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
         req.sampling_params.max_new_tokens = min(
             req.sampling_params.max_new_tokens,
-            self.model_config.context_len - 1 - len(req.
-            self.
+            self.model_config.context_len - 1 - len(req.origin_input_ids),
+            self.max_total_num_tokens - 128 - len(req.origin_input_ids),
         )
         self.forward_queue.append(req)

     def get_new_fill_batch(self):
         if (
             self.running_batch is not None
-            and len(self.running_batch.reqs) > self.
+            and len(self.running_batch.reqs) > self.max_running_requests
         ):
             return None

+        # Compute matched prefix length
         for req in self.forward_queue:
+            assert (
+                len(req.output_ids) == 0
+            ), "The output ids should be empty when prefilling"
+            req.input_ids = req.origin_input_ids + req.prev_output_ids
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
             if req.return_logprob:
                 prefix_indices = prefix_indices[: req.logprob_start_len]
@@ -321,7 +342,7 @@ class ModelRpcServer:
             )

         for req in self.forward_queue:
-            if req.return_logprob:
+            if req.return_logprob and req.normalized_prompt_logprob is None:
                 # Need at least two tokens to compute normalized logprob
                 if req.extend_input_len < 2:
                     delta = 2 - req.extend_input_len
@@ -340,22 +361,21 @@ class ModelRpcServer:
                 req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                 < available_size
                 and req.extend_input_len + new_batch_input_tokens
-                < self.
+                < self.max_prefill_tokens
             ):
-                delta = self.tree_cache.
+                delta = self.tree_cache.inc_lock_ref(req.last_node)
                 available_size += delta

                 if not (
                     req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                     < available_size
                 ):
-                    # Undo
-                    delta = self.tree_cache.
+                    # Undo locking
+                    delta = self.tree_cache.dec_lock_ref(req.last_node)
                     available_size += delta
                     break
                 else:
                     # Add this request to the running batch
-                    self.token_to_kv_pool.add_refs(req.prefix_indices)
                     can_run_list.append(req)
                     new_batch_total_tokens += (
                         req.extend_input_len + req.max_new_tokens()
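The admission test in the hunk above caps a new fill batch by two budgets: the remaining KV space must cover the prompt plus its worst-case new tokens, and the batch's prefill tokens must stay under the per-batch prefill limit. A toy, self-contained version of that loop follows; the numbers are illustrative and the radix-cache lock/unlock bookkeeping from the real code is omitted.

```python
# Toy version of the prefill admission loop shown above.
available_size = 8_000        # free + evictable KV tokens (illustrative)
max_prefill_tokens = 4_096    # per-batch prefill budget (illustrative)
requests = [(512, 256), (1024, 512), (3000, 1024)]  # (extend_input_len, max_new_tokens)

can_run_list = []
new_batch_total_tokens = 0
new_batch_input_tokens = 0
for extend_input_len, max_new_tokens in requests:
    if (
        extend_input_len + max_new_tokens + new_batch_total_tokens < available_size
        and extend_input_len + new_batch_input_tokens < max_prefill_tokens
    ):
        can_run_list.append((extend_input_len, max_new_tokens))
        new_batch_total_tokens += extend_input_len + max_new_tokens
        new_batch_input_tokens += extend_input_len
    else:
        break
print(f"admitted {len(can_run_list)} of {len(requests)} queued requests")  # 2 of 3
```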
@@ -366,6 +386,7 @@ class ModelRpcServer:
         if len(can_run_list) == 0:
             return None

+        # Print stats
         if self.tp_rank == 0:
             running_req = (
                 0 if self.running_batch is None else len(self.running_batch.reqs)
@@ -386,13 +407,14 @@ class ModelRpcServer:
                 f"#running_req: {running_req}. "
                 f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
             )
-            #logger.debug(
+            # logger.debug(
             # f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
             # f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
             # f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
             # f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
-            #)
+            # )

+        # Return the new batch
         new_batch = Batch.init_new(
             can_run_list,
             self.req_to_token_pool,
@@ -425,9 +447,10 @@ class ModelRpcServer:

         # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
         if last_logprobs is not None:
-            last_token_logprobs =
-
-
+            last_token_logprobs = last_logprobs[
+                torch.arange(len(batch.reqs), device=next_token_ids.device),
+                next_token_ids,
+            ].tolist()

             next_token_ids = next_token_ids.tolist()
         else:
@@ -441,38 +464,75 @@ class ModelRpcServer:
             req.check_finished()

             if req.return_logprob:
-                req.normalized_prompt_logprob
-
-
-                req.prefill_token_logprobs
-
-
-
+                if req.normalized_prompt_logprob is None:
+                    req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
+
+                if req.prefill_token_logprobs is None:
+                    # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+                    req.prefill_token_logprobs = list(
+                        zip(
+                            prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                            req.input_ids[-req.extend_input_len + 1 :],
+                        )
                     )
-
-
-
-
-
-                req.
+                    if req.logprob_start_len == 0:
+                        req.prefill_token_logprobs = [
+                            (None, req.input_ids[0])
+                        ] + req.prefill_token_logprobs
+
+                if req.last_update_decode_tokens != 0:
+                    req.decode_token_logprobs.extend(
+                        list(
+                            zip(
+                                prefill_token_logprobs[
+                                    pt
+                                    + req.extend_input_len
+                                    - req.last_update_decode_tokens : pt
+                                    + req.extend_input_len
+                                    - 1
+                                ],
+                                req.input_ids[-req.last_update_decode_tokens + 1 :],
+                            )
+                        )
+                    )
+
+                req.decode_token_logprobs.append(
                     (last_token_logprobs[i], next_token_ids[i])
-
+                )

             if req.top_logprobs_num > 0:
-                req.prefill_top_logprobs
-
-                req.
-
+                if req.prefill_top_logprobs is None:
+                    req.prefill_top_logprobs = prefill_top_logprobs[i]
+                    if req.logprob_start_len == 0:
+                        req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+
+                if req.last_update_decode_tokens != 0:
+                    req.decode_top_logprobs.extend(
+                        prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+                    )
+                req.decode_top_logprobs.append(decode_top_logprobs[i])

             pt += req.extend_input_len

         self.handle_finished_requests(batch)

+    def cache_filled_batch(self, batch: Batch):
+        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
+        for i, req in enumerate(batch.reqs):
+            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
+                token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                last_uncached_pos=len(req.prefix_indices),
+                req_pool_idx=req_pool_indices_cpu[i],
+                del_in_memory_pool=False,
+                old_last_node=req.last_node,
+            )
+            req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
+
     def forward_decode_batch(self, batch: Batch):
         # check if decode out of memory
         if not batch.check_decode_mem():
             old_ratio = self.new_token_ratio
-            self.new_token_ratio = min(old_ratio + self.
+            self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)

             retracted_reqs = batch.retract_decode()
             logger.info(
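The prefill-logprob bookkeeping above pairs each logprob with the token it scores: the logprob at position j corresponds to the token at position j + 1, so the first prompt token has no score and gets a (None, token) placeholder when logprob_start_len is 0. A tiny illustration with made-up values, simplified to a single request (the real code offsets the slice by pt for each request in the batch):

```python
# Made-up token ids and logprobs to illustrate the (logprob, token_id) pairing above.
input_ids = [101, 7592, 2088, 102]           # 4 prompt tokens
prefill_token_logprobs = [-0.5, -1.2, -0.3]  # logprobs of tokens 2..4 given their prefix
extend_input_len = len(input_ids)

pairs = list(zip(prefill_token_logprobs[: extend_input_len - 1],
                 input_ids[-extend_input_len + 1 :]))
pairs = [(None, input_ids[0])] + pairs       # first token has no conditioning context
print(pairs)  # [(None, 101), (-0.5, 7592), (-1.2, 2088), (-0.3, 102)]
```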
@@ -483,26 +543,13 @@ class ModelRpcServer:
             self.forward_queue.extend(retracted_reqs)
         else:
             self.new_token_ratio = max(
-                self.new_token_ratio - self.
+                self.new_token_ratio - self.new_token_ratio_decay,
                 self.min_new_token_ratio,
             )

         if not self.disable_regex_jump_forward:
             # check for jump-forward
-            jump_forward_reqs = batch.check_for_jump_forward()
-
-            # check for image jump-forward
-            for req in jump_forward_reqs:
-                if req.pixel_values is not None:
-                    (
-                        req.input_ids,
-                        req.image_offset,
-                    ) = self.model_runner.model.pad_input_ids(
-                        req.input_ids,
-                        req.pad_value,
-                        req.pixel_values.shape,
-                        req.image_size,
-                    )
+            jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)

             self.forward_queue.extend(jump_forward_reqs)
             if batch.is_empty():
@@ -545,8 +592,8 @@ class ModelRpcServer:

     def handle_finished_requests(self, batch: Batch):
         output_rids = []
+        prev_output_strs = []
         output_tokens = []
-        output_and_jump_forward_strs = []
         output_hit_stop_str = []
         output_skip_special_tokens = []
         output_spaces_between_special_tokens = []
@@ -570,8 +617,8 @@ class ModelRpcServer:
                 )
             ):
                 output_rids.append(req.rid)
+                prev_output_strs.append(req.prev_output_str)
                 output_tokens.append(req.output_ids)
-                output_and_jump_forward_strs.append(req.output_and_jump_forward_str)
                 output_hit_stop_str.append(req.hit_stop_str)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
@@ -581,12 +628,11 @@ class ModelRpcServer:
                 )

                 meta_info = {
-                    "prompt_tokens": req.
-                    "completion_tokens": len(req.
-                    + len(req.output_ids)
-                    - req.prompt_tokens,
+                    "prompt_tokens": len(req.origin_input_ids),
+                    "completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                    "finish_reason":
+                    "finish_reason": FinishReason.to_str(req.finish_reason),
+                    "hit_stop_str": req.hit_stop_str,
                 }
                 if req.return_logprob:
                     (
@@ -610,8 +656,8 @@ class ModelRpcServer:
             self.out_pyobjs.append(
                 BatchTokenIDOut(
                     output_rids,
+                    prev_output_strs,
                     output_tokens,
-                    output_and_jump_forward_strs,
                     output_hit_stop_str,
                     output_skip_special_tokens,
                     output_spaces_between_special_tokens,
@@ -626,17 +672,13 @@ class ModelRpcServer:
             req_pool_indices_cpu = batch.req_pool_indices.tolist()
             for i in finished_indices:
                 req = batch.reqs[i]
-
-
-
-
-                prefix_len = self.tree_cache.insert(
-                    token_ids[:seq_len], indices.clone()
+                self.tree_cache.cache_req(
+                    token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                    last_uncached_pos=len(req.prefix_indices),
+                    req_pool_idx=req_pool_indices_cpu[i],
                 )

-                self.
-                self.req_to_token_pool.free(req_pool_idx)
-                self.tree_cache.dec_ref_counter(req.last_node)
+                self.tree_cache.dec_lock_ref(req.last_node)

             # Update batch tensors
             if unfinished_indices:
@@ -644,19 +686,58 @@ class ModelRpcServer:
         else:
             batch.reqs = []

+    def flush_cache(self):
+        if len(self.forward_queue) == 0 and (
+            self.running_batch is None or len(self.running_batch.reqs) == 0
+        ):
+            self.tree_cache.reset()
+            self.tree_cache_metrics = {"total": 0, "hit": 0}
+            self.regex_fsm_cache.reset()
+            self.req_to_token_pool.clear()
+            self.token_to_kv_pool.clear()
+            torch.cuda.empty_cache()
+            logger.info("Cache flushed successfully!")
+        else:
+            warnings.warn(
+                f"Cache not flushed because there are pending requests. "
+                f"#queue-req: {len(self.forward_queue)}, "
+                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+            )
+
+    def abort_request(self, recv_req):
+        # Delete requests in the waiting queue
+        to_del = None
+        for i, req in enumerate(self.forward_queue):
+            if req.rid == recv_req.rid:
+                to_del = i
+                break
+
+        if to_del is not None:
+            del self.forward_queue[to_del]
+
+        # Delete requests in the running batch
+        if self.running_batch:
+            for req in self.running_batch.reqs:
+                if req.rid == recv_req.rid:
+                    req.finished = True
+                    req.finish_reason = FinishReason.ABORT
+                    break
+

 class ModelRpcService(rpyc.Service):
     exposed_ModelRpcServer = ModelRpcServer


 class ModelRpcClient:
-    def __init__(
+    def __init__(
+        self, server_args: ServerArgs, port_args: PortArgs, model_overide_args
+    ):
         tp_size = server_args.tp_size

         if tp_size == 1:
             # Init model
             self.model_server = ModelRpcService().exposed_ModelRpcServer(
-                0, server_args, port_args
+                0, server_args, port_args, model_overide_args
             )

             # Wrap functions
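The abort path added above removes a request from the waiting queue or, if it is already running, marks it finished with FinishReason.ABORT, keyed by its rid. A minimal stand-in of that logic, using simplified dataclasses instead of the real sglang request and io_struct types:

```python
from dataclasses import dataclass

# Simplified stand-ins for the real sglang objects; only the fields touched by
# the abort path above are modeled, and "abort" stands in for FinishReason.ABORT.
@dataclass
class Req:
    rid: str
    finished: bool = False
    finish_reason: str = ""

@dataclass
class AbortReq:
    rid: str

def abort_request(recv_req, forward_queue, running_reqs):
    # Drop the request from the waiting queue, if present.
    forward_queue[:] = [r for r in forward_queue if r.rid != recv_req.rid]
    # Mark a matching running request as finished so the next step flushes it out.
    for req in running_reqs:
        if req.rid == recv_req.rid:
            req.finished = True
            req.finish_reason = "abort"
            break

queue, running = [Req("a")], [Req("b")]
abort_request(AbortReq("b"), queue, running)
print(len(queue), running[0].finished, running[0].finish_reason)  # 1 True abort
```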
@@ -677,7 +758,7 @@ class ModelRpcClient:
             # Init model
             def init_model(i):
                 return self.remote_services[i].ModelRpcServer(
-                    i, server_args, port_args
+                    i, server_args, port_args, model_overide_args
                 )

             self.model_servers = executor.map(init_model, range(tp_size))
@@ -700,7 +781,11 @@ def _init_service(port):
     t = ThreadedServer(
         ModelRpcService(),
         port=port,
-        protocol_config={
+        protocol_config={
+            "allow_public_attrs": True,
+            "allow_pickle": True,
+            "sync_request_timeout": 3600,
+        },
     )
     t.start()

@@ -716,7 +801,11 @@ def start_model_process(port):
             con = rpyc.connect(
                 "localhost",
                 port,
-                config={
+                config={
+                    "allow_public_attrs": True,
+                    "allow_pickle": True,
+                    "sync_request_timeout": 3600,
+                },
             )
             break
         except ConnectionRefusedError: