sglang 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sglang/__init__.py +3 -1
- sglang/api.py +7 -7
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +158 -11
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/bench_latency.py +299 -0
- sglang/global_config.py +12 -2
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +114 -67
- sglang/lang/ir.py +28 -3
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +13 -6
- sglang/srt/constrained/fsm_cache.py +8 -2
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +3 -1
- sglang/srt/hf_transformers_utils.py +130 -1
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +582 -0
- sglang/srt/layers/logits_processor.py +65 -32
- sglang/srt/layers/radix_attention.py +41 -7
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/{router → controller}/infer_batch.py +242 -100
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/{router/manager.py → controller/manager_single.py} +34 -14
- sglang/srt/managers/{router → controller}/model_runner.py +262 -158
- sglang/srt/managers/{router → controller}/radix_cache.py +11 -1
- sglang/srt/managers/{router/scheduler.py → controller/schedule_heuristic.py} +9 -7
- sglang/srt/managers/{router/model_rpc.py → controller/tp_worker.py} +298 -267
- sglang/srt/managers/detokenizer_manager.py +42 -46
- sglang/srt/managers/io_struct.py +22 -12
- sglang/srt/managers/tokenizer_manager.py +151 -87
- sglang/srt/model_config.py +83 -5
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +10 -13
- sglang/srt/models/dbrx.py +9 -15
- sglang/srt/models/gemma.py +12 -15
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +26 -15
- sglang/srt/models/llama_classification.py +104 -0
- sglang/srt/models/llava.py +86 -19
- sglang/srt/models/llavavid.py +11 -20
- sglang/srt/models/mixtral.py +282 -103
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +9 -13
- sglang/srt/models/qwen2.py +11 -13
- sglang/srt/models/stablelm.py +9 -15
- sglang/srt/models/yivl.py +17 -22
- sglang/srt/openai_api_adapter.py +150 -95
- sglang/srt/openai_protocol.py +11 -2
- sglang/srt/server.py +124 -48
- sglang/srt/server_args.py +128 -48
- sglang/srt/utils.py +234 -67
- sglang/test/test_programs.py +65 -3
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +23 -4
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/METADATA +40 -27
- sglang-0.1.18.dist-info/RECORD +78 -0
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -417
- sglang-0.1.16.dist-info/RECORD +0 -72
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
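The listing shows that the sglang/srt/managers/router package was renamed to sglang/srt/managers/controller in 0.1.18. As a minimal, hypothetical sketch (module paths taken from the rename entries above and from the tp_worker diff below, not from any release notes), code importing these modules would move from the router paths to the controller paths:

# Import paths in sglang 0.1.16 (removed by this release):
#   from sglang.srt.managers.router.model_runner import ModelRunner
#   from sglang.srt.managers.router.radix_cache import RadixCache
#   from sglang.srt.managers.router.scheduler import Scheduler

# Equivalent import paths in sglang 0.1.18 (as used in controller/tp_worker.py below):
from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic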
--- sglang/srt/managers/router/model_rpc.py
+++ sglang/srt/managers/controller/tp_worker.py
@@ -1,61 +1,68 @@
+"""A tensor parallel worker."""
+
 import asyncio
 import logging
-import multiprocessing
 import time
 import warnings
 from concurrent.futures import ThreadPoolExecutor
-from typing import
+from typing import List, Optional
 
 import rpyc
 import torch
 from rpyc.utils.classic import obtain
-from rpyc.utils.server import ThreadedServer
-
-try:
-    from vllm.logger import _default_handler as vllm_default_logger
-except ImportError:
-    from vllm.logger import logger as vllm_default_logger
 
+from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.managers.controller.infer_batch import (
+    FINISH_ABORT,
+    BaseFinishReason,
+    Batch,
+    ForwardMode,
+    Req,
+)
+from sglang.srt.managers.controller.model_runner import ModelRunner
+from sglang.srt.managers.controller.radix_cache import RadixCache
+from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
 from sglang.srt.managers.io_struct import (
+    AbortReq,
     BatchTokenIDOut,
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req, FinishReason
-from sglang.srt.managers.router.model_runner import ModelRunner
-from sglang.srt.managers.router.radix_cache import RadixCache
-from sglang.srt.managers.router.scheduler import Scheduler
 from sglang.srt.model_config import ModelConfig
-from sglang.srt.server_args import
+from sglang.srt.server_args import ModelPortArgs, ServerArgs
 from sglang.srt.utils import (
-    get_exception_traceback,
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
+    start_rpyc_service_process,
+    connect_rpyc_service,
+    suppress_other_loggers,
 )
+from sglang.utils import get_exception_traceback
 
-
-logger = logging.getLogger("model_rpc")
-vllm_default_logger.setLevel(logging.WARN)
-logging.getLogger("vllm.utils").setLevel(logging.WARN)
+logger = logging.getLogger("srt.tp_worker")
 
 
-class ModelRpcServer:
+class ModelTpServer:
     def __init__(
         self,
+        gpu_id: int,
         tp_rank: int,
         server_args: ServerArgs,
-
-        model_overide_args
+        model_port_args: ModelPortArgs,
+        model_overide_args,
     ):
-        server_args,
+        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
+        suppress_other_loggers()
 
         # Copy arguments
+        self.gpu_id = gpu_id
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
+        self.dp_size = server_args.dp_size
         self.schedule_heuristic = server_args.schedule_heuristic
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
 
@@ -66,23 +73,16 @@ class ModelRpcServer:
             context_length=server_args.context_length,
             model_overide_args=model_overide_args,
         )
-
-        # For model end global settings
-        server_args_dict = {
-            "enable_flashinfer": server_args.enable_flashinfer,
-            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-        }
-
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
+            gpu_id=gpu_id,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
-            nccl_port=
-
-            trust_remote_code=server_args.trust_remote_code,
-            server_args_dict=server_args_dict,
+            nccl_port=model_port_args.nccl_port,
+            server_args=server_args,
         )
+
         if is_multimodal_model(server_args.model_path):
             self.processor = get_processor(
                 server_args.tokenizer_path,
@@ -96,28 +96,34 @@ class ModelRpcServer:
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
             )
-        self.
-        self.
-
-
-
-
-
-
-
+        self.max_total_num_tokens = self.model_runner.max_total_num_tokens
+        self.max_prefill_tokens = (
+            4096
+            if server_args.max_prefill_tokens is None
+            else server_args.max_prefill_tokens
+        )
+        self.max_running_requests = (
+            self.max_total_num_tokens // 2
+            if server_args.max_running_requests is None
+            else server_args.max_running_requests
         )
         self.int_token_logit_bias = torch.tensor(
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
         )
         set_random_seed(server_args.random_seed)
+
+        # Print info
         logger.info(
-            f"
-            f"
-            f"
-            f"context_len={self.model_config.context_len}
+            f"[gpu_id={self.gpu_id}] "
+            f"max_total_num_tokens={self.max_total_num_tokens}, "
+            f"max_prefill_tokens={self.max_prefill_tokens}, "
+            f"context_len={self.model_config.context_len}"
         )
         if self.tp_rank == 0:
-            logger.info(
+            logger.info(
+                f"[gpu_id={self.gpu_id}] "
+                f"server_args: {server_args.print_mode_args()}"
+            )
 
         # Init cache
         self.tree_cache = RadixCache(
@@ -126,11 +132,11 @@ class ModelRpcServer:
             disable=server_args.disable_radix_cache,
         )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
-        self.scheduler =
+        self.scheduler = ScheduleHeuristic(
             self.schedule_heuristic,
-            self.
-            self.
-            self.
+            self.max_running_requests,
+            self.max_prefill_tokens,
+            self.max_total_num_tokens,
             self.tree_cache,
         )
         self.req_to_token_pool = self.model_runner.req_to_token_pool
@@ -156,30 +162,23 @@ class ModelRpcServer:
         self.jump_forward_cache = JumpForwardCache()
 
         # Init new token estimation
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            logger.info("Cache flushed successfully!")
-        else:
-            warnings.warn(
-                f"Cache not flushed because there are pending requests. "
-                f"#queue-req: {len(self.forward_queue)}, "
-                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
-            )
+        assert (
+            server_args.schedule_conservativeness >= 0
+        ), "Invalid schedule_conservativeness"
+        self.new_token_ratio = min(
+            global_config.base_new_token_ratio * server_args.schedule_conservativeness,
+            1.0,
+        )
+        self.min_new_token_ratio = min(
+            global_config.base_min_new_token_ratio
+            * server_args.schedule_conservativeness,
+            1.0,
+        )
+        self.new_token_ratio_decay = global_config.new_token_ratio_decay
+        self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
 
     def exposed_step(self, recv_reqs):
-        if self.tp_size != 1:
+        if self.tp_size * self.dp_size != 1:
             recv_reqs = obtain(recv_reqs)
 
         try:
@@ -189,13 +188,16 @@ class ModelRpcServer:
                     self.handle_generate_request(recv_req)
                 elif isinstance(recv_req, FlushCacheReq):
                     self.flush_cache()
+                elif isinstance(recv_req, AbortReq):
+                    self.abort_request(recv_req)
                 else:
                     raise ValueError(f"Invalid request: {recv_req}")
 
             # Forward
             self.forward_step()
         except Exception:
-            logger.error("Exception in
+            logger.error("Exception in ModelTpServer:\n" + get_exception_traceback())
+            raise
 
         # Return results
         ret = self.out_pyobjs
@@ -207,9 +209,8 @@ class ModelRpcServer:
         new_batch = self.get_new_fill_batch()
 
         if new_batch is not None:
-            # Run new fill batch
+            # Run a new fill batch
             self.forward_fill_batch(new_batch)
-
             self.cache_filled_batch(new_batch)
 
             if not new_batch.is_empty():
@@ -225,39 +226,43 @@ class ModelRpcServer:
                     self.num_generated_tokens += len(self.running_batch.reqs)
                     self.forward_decode_batch(self.running_batch)
 
-
-
-                        break
-
-                    if self.out_pyobjs and self.running_batch.reqs[0].stream:
-                        break
-
-                    if self.running_batch is not None and self.tp_rank == 0:
+                    # Print stats
+                    if self.tp_rank == 0:
                         if self.decode_forward_ct % 40 == 0:
-                            num_used = self.
+                            num_used = self.max_total_num_tokens - (
                                 self.token_to_kv_pool.available_size()
                                 + self.tree_cache.evictable_size()
                             )
-
+                            throughput = self.num_generated_tokens / (
+                                time.time() - self.last_stats_tic
+                            )
                             self.num_generated_tokens = 0
                             self.last_stats_tic = time.time()
                             logger.info(
+                                f"[gpu_id={self.gpu_id}] Decode batch. "
                                 f"#running-req: {len(self.running_batch.reqs)}, "
                                 f"#token: {num_used}, "
-                                f"token usage: {num_used / self.
-                                f"gen throughput (token/s): {
+                                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                                f"gen throughput (token/s): {throughput:.2f}, "
                                 f"#queue-req: {len(self.forward_queue)}"
                            )
+
+                    if self.running_batch.is_empty():
+                        self.running_batch = None
+                        break
+
+                    if self.out_pyobjs and self.running_batch.has_stream():
+                        break
             else:
-                #
+                # Check the available size
                 available_size = (
                     self.token_to_kv_pool.available_size()
                     + self.tree_cache.evictable_size()
                 )
-                if available_size != self.
+                if available_size != self.max_total_num_tokens:
                     warnings.warn(
                         "Warning: "
-                        f"available_size={available_size},
+                        f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
                         "KV cache pool leak detected!"
                     )
 
@@ -275,8 +280,14 @@ class ModelRpcServer:
                 (recv_req.image_hash >> 64) % self.model_config.vocab_size,
             ]
             req.image_size = recv_req.image_size
-
-            req.
+            (
+                req.origin_input_ids,
+                req.image_offset,
+            ) = self.model_runner.model.pad_input_ids(
+                req.origin_input_ids_unpadded,
+                req.pad_value,
+                req.pixel_values.shape,
+                req.image_size,
             )
         req.sampling_params = recv_req.sampling_params
         req.return_logprob = recv_req.return_logprob
@@ -293,23 +304,25 @@ class ModelRpcServer:
                 req.sampling_params.regex
             )
 
-        # Truncate long
-        req.
+        # Truncate prompts that are too long
+        req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
         req.sampling_params.max_new_tokens = min(
             req.sampling_params.max_new_tokens,
-            self.model_config.context_len - 1 - len(req.
-            self.
+            self.model_config.context_len - 1 - len(req.origin_input_ids),
+            self.max_total_num_tokens - 128 - len(req.origin_input_ids),
        )
         self.forward_queue.append(req)
 
-    def get_new_fill_batch(self):
+    def get_new_fill_batch(self) -> Optional[Batch]:
         if (
             self.running_batch is not None
-            and len(self.running_batch.reqs) > self.
+            and len(self.running_batch.reqs) > self.max_running_requests
         ):
             return None
 
+        # Compute matched prefix length
         for req in self.forward_queue:
+            req.input_ids = req.origin_input_ids + req.output_ids
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
             if req.return_logprob:
                 prefix_indices = prefix_indices[: req.logprob_start_len]
@@ -337,7 +350,7 @@ class ModelRpcServer:
         )
 
         for req in self.forward_queue:
-            if req.return_logprob:
+            if req.return_logprob and req.normalized_prompt_logprob is None:
                 # Need at least two tokens to compute normalized logprob
                 if req.extend_input_len < 2:
                     delta = 2 - req.extend_input_len
@@ -355,8 +368,9 @@ class ModelRpcServer:
             if (
                 req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                 < available_size
-                and req.extend_input_len + new_batch_input_tokens
-
+                and (req.extend_input_len + new_batch_input_tokens
+                    <= self.max_prefill_tokens
+                    or len(can_run_list) == 0)
             ):
                 delta = self.tree_cache.inc_lock_ref(req.last_node)
                 available_size += delta
@@ -381,6 +395,7 @@ class ModelRpcServer:
         if len(can_run_list) == 0:
             return None
 
+        # Print stats
         if self.tp_rank == 0:
             running_req = (
                 0 if self.running_batch is None else len(self.running_batch.reqs)
@@ -394,20 +409,22 @@ class ModelRpcServer:
                 self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
             )
             logger.info(
-                f"
-                f"#
-                f"#
-                f"#
-                f"
-                f"
+                f"[gpu_id={self.gpu_id}] Prefill batch. "
+                f"#new-seq: {len(can_run_list)}, "
+                f"#new-token: {new_batch_input_tokens}, "
+                f"#cached-token: {hit_tokens}, "
+                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"#running-req: {running_req}, "
+                f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
            )
-            #logger.debug(
+            # logger.debug(
             # f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
             # f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
             # f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
             # f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
-            #)
+            # )
 
+        # Return the new batch
         new_batch = Batch.init_new(
             can_run_list,
             self.req_to_token_pool,
@@ -423,73 +440,91 @@ class ModelRpcServer:
             self.model_config.vocab_size, self.int_token_logit_bias
         )
 
+        # Forward and sample the next tokens
         if batch.extend_num_tokens != 0:
-
-
-
-
-
-
-
-
-
-            prefill_token_logprobs = prefill_token_logprobs.tolist()
-            normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()
-
-            next_token_ids, _ = batch.sample(logits)
-
-            # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
-            if last_logprobs is not None:
-                last_token_logprobs = (
-                    last_logprobs[
-                        torch.arange(len(batch.reqs), device=next_token_ids.device),
-                        next_token_ids].tolist()
-                )
+            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            next_token_ids, _ = batch.sample(output.next_token_logits)
+
+            # Move logprobs to cpu
+            if output.next_token_logprobs is not None:
+                output.next_token_logprobs = output.next_token_logprobs[
+                    torch.arange(len(next_token_ids), device=next_token_ids.device),
+                    next_token_ids,
+                ].tolist()
+                output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
+                output.normalized_prompt_logprobs = output.normalized_prompt_logprobs.tolist()
 
             next_token_ids = next_token_ids.tolist()
         else:
             next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
 
-        # Check finish
+        # Check finish conditions
         pt = 0
         for i, req in enumerate(batch.reqs):
             req.completion_tokens_wo_jump_forward += 1
-            req.output_ids
+            req.output_ids.append(next_token_ids[i])
             req.check_finished()
 
             if req.return_logprob:
-
+                self.add_logprob_return_values(i, req, pt, next_token_ids, output)
+                pt += req.extend_input_len
+
+        self.handle_finished_requests(batch)
+
+    def add_logprob_return_values(self, i, req, pt, next_token_ids, output):
+        if req.normalized_prompt_logprob is None:
+            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
 
-
-
+        if req.prefill_token_logprobs is None:
+            # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+            req.prefill_token_logprobs = list(
+                zip(
+                    output.prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                    req.input_ids[-req.extend_input_len + 1 :],
+                )
+            )
+            if req.logprob_start_len == 0:
+                req.prefill_token_logprobs = [
+                    (None, req.input_ids[0])
+                ] + req.prefill_token_logprobs
+
+        if req.last_update_decode_tokens != 0:
+            req.decode_token_logprobs.extend(
+                list(
                     zip(
-                        prefill_token_logprobs[
-
+                        output.prefill_token_logprobs[
+                            pt
+                            + req.extend_input_len
+                            - req.last_update_decode_tokens : pt
+                            + req.extend_input_len
+                            - 1
+                        ],
+                        req.input_ids[-req.last_update_decode_tokens + 1 :],
                     )
                 )
-
-
-
-
-
-                        (last_token_logprobs[i], next_token_ids[i])
-                    ]
+            )
+
+        req.decode_token_logprobs.append(
+            (output.next_token_logprobs[i], next_token_ids[i])
+        )
 
-
-
+        if req.top_logprobs_num > 0:
+            if req.prefill_top_logprobs is None:
+                req.prefill_top_logprobs = output.prefill_top_logprobs[i]
                 if req.logprob_start_len == 0:
                     req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
-                req.decode_top_logprobs = [decode_top_logprobs[i]]
 
-
-
-
+            if req.last_update_decode_tokens != 0:
+                req.decode_top_logprobs.extend(
+                    output.prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+                )
+            req.decode_top_logprobs.append(output.decode_top_logprobs[i])
 
     def cache_filled_batch(self, batch: Batch):
-        req_pool_indices_cpu = batch.req_pool_indices.cpu().
+        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
         for i, req in enumerate(batch.reqs):
             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                token_ids=tuple(req.
+                token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                 last_uncached_pos=len(req.prefix_indices),
                 req_pool_idx=req_pool_indices_cpu[i],
                 del_in_memory_pool=False,
@@ -498,10 +533,10 @@ class ModelRpcServer:
             req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
 
     def forward_decode_batch(self, batch: Batch):
-        #
+        # Check if decode out of memory
         if not batch.check_decode_mem():
             old_ratio = self.new_token_ratio
-            self.new_token_ratio = min(old_ratio + self.
+            self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
 
             retracted_reqs = batch.retract_decode()
             logger.info(
@@ -512,27 +547,13 @@ class ModelRpcServer:
             self.forward_queue.extend(retracted_reqs)
         else:
             self.new_token_ratio = max(
-                self.new_token_ratio - self.
+                self.new_token_ratio - self.new_token_ratio_decay,
                 self.min_new_token_ratio,
             )
 
         if not self.disable_regex_jump_forward:
-            #
-            jump_forward_reqs = batch.check_for_jump_forward()
-
-            # check for image jump-forward
-            for req in jump_forward_reqs:
-                if req.pixel_values is not None:
-                    (
-                        req.input_ids,
-                        req.image_offset,
-                    ) = self.model_runner.model.pad_input_ids(
-                        req.input_ids,
-                        req.pad_value,
-                        req.pixel_values.shape,
-                        req.image_size,
-                    )
-
+            # Check for jump-forward
+            jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
             self.forward_queue.extend(jump_forward_reqs)
             if batch.is_empty():
                 return
@@ -541,23 +562,19 @@ class ModelRpcServer:
         self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
         batch.prepare_for_decode()
 
-        # Forward
-
-
-            _,
-            _,
-            decode_top_logprobs,
-            last_logprobs,
-        ) = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids, _ = batch.sample(logits)
-        next_token_ids = next_token_ids.tolist()
+        # Forward and sample the next tokens
+        output = self.model_runner.forward(batch, ForwardMode.DECODE)
+        next_token_ids, _ = batch.sample(output.next_token_logits)
 
-        #
-        if
-
-                torch.arange(len(
+        # Move logprobs to cpu
+        if output.next_token_logprobs is not None:
+            next_token_logprobs = output.next_token_logprobs[
+                torch.arange(len(next_token_ids), device=next_token_ids.device),
+                next_token_ids,
             ].tolist()
 
+        next_token_ids = next_token_ids.tolist()
+
         # Check finish condition
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             req.completion_tokens_wo_jump_forward += 1
@@ -565,31 +582,30 @@ class ModelRpcServer:
             req.check_finished()
 
             if req.return_logprob:
-                req.decode_token_logprobs.append((
-
-
-                req.decode_top_logprobs.append(decode_top_logprobs[i])
+                req.decode_token_logprobs.append((next_token_logprobs[i], next_token_id))
+                if req.top_logprobs_num > 0:
+                    req.decode_top_logprobs.append(output.decode_top_logprobs[i])
 
         self.handle_finished_requests(batch)
 
     def handle_finished_requests(self, batch: Batch):
         output_rids = []
-
-
-
+        decoded_texts = []
+        surr_output_ids = []
+        read_output_ids = []
         output_skip_special_tokens = []
         output_spaces_between_special_tokens = []
         output_meta_info = []
-
+        output_finished_reason: List[BaseFinishReason] = []
         finished_indices = []
         unfinished_indices = []
         for i, req in enumerate(batch.reqs):
-            if req.finished:
+            if req.finished():
                 finished_indices.append(i)
             else:
                 unfinished_indices.append(i)
 
-            if req.finished or (
+            if req.finished() or (
                 (
                     req.stream
                     and (
@@ -599,9 +615,10 @@ class ModelRpcServer:
                 )
             ):
                 output_rids.append(req.rid)
-
-
-
+                decoded_texts.append(req.decoded_text)
+                surr_ids, read_ids, _ = req.init_detokenize_incrementally()
+                surr_output_ids.append(surr_ids)
+                read_output_ids.append(read_ids)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
                 )
@@ -610,13 +627,10 @@ class ModelRpcServer:
                 )
 
                 meta_info = {
-                    "prompt_tokens": req.
-                    "completion_tokens": len(req.
-                    + len(req.output_ids)
-                    - req.prompt_tokens,
+                    "prompt_tokens": len(req.origin_input_ids),
+                    "completion_tokens": len(req.output_ids),
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                    "finish_reason":
-                    "hit_stop_str": req.hit_stop_str,
+                    "finish_reason": str(req.finished_reason),
                 }
                 if req.return_logprob:
                     (
@@ -633,20 +647,20 @@ class ModelRpcServer:
                         req.normalized_prompt_logprob,
                     )
                 output_meta_info.append(meta_info)
-
+                output_finished_reason.append(req.finished_reason)
 
         # Send to detokenizer
         if output_rids:
             self.out_pyobjs.append(
                 BatchTokenIDOut(
                     output_rids,
-
-
-
+                    decoded_texts,
+                    surr_output_ids,
+                    read_output_ids,
                     output_skip_special_tokens,
                     output_spaces_between_special_tokens,
                     output_meta_info,
-
+                    output_finished_reason,
                 )
             )
 
@@ -657,7 +671,7 @@ class ModelRpcServer:
             for i in finished_indices:
                 req = batch.reqs[i]
                 self.tree_cache.cache_req(
-                    token_ids=tuple(req.
+                    token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                     last_uncached_pos=len(req.prefix_indices),
                     req_pool_idx=req_pool_indices_cpu[i],
                 )
@@ -670,21 +684,67 @@ class ModelRpcServer:
         else:
             batch.reqs = []
 
+    def flush_cache(self):
+        if len(self.forward_queue) == 0 and (
+            self.running_batch is None or len(self.running_batch.reqs) == 0
+        ):
+            self.tree_cache.reset()
+            self.tree_cache_metrics = {"total": 0, "hit": 0}
+            self.regex_fsm_cache.reset()
+            self.req_to_token_pool.clear()
+            self.token_to_kv_pool.clear()
+            torch.cuda.empty_cache()
+            logger.info("Cache flushed successfully!")
+        else:
+            warnings.warn(
+                f"Cache not flushed because there are pending requests. "
+                f"#queue-req: {len(self.forward_queue)}, "
+                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+            )
+
+    def abort_request(self, recv_req):
+        # Delete requests in the waiting queue
+        to_del = None
+        for i, req in enumerate(self.forward_queue):
+            if req.rid == recv_req.rid:
+                to_del = i
+                break
+
+        if to_del is not None:
+            del self.forward_queue[to_del]
+
+        # Delete requests in the running batch
+        if self.running_batch:
+            for req in self.running_batch.reqs:
+                if req.rid == recv_req.rid:
+                    req.finished_reason = FINISH_ABORT()
+                    break
 
-class ModelRpcService(rpyc.Service):
-    exposed_ModelRpcServer = ModelRpcServer
 
+class ModelTpService(rpyc.Service):
+    exposed_ModelTpServer = ModelTpServer
 
-class ModelRpcClient:
+
+class ModelTpClient:
     def __init__(
-        self,
+        self,
+        gpu_ids: List[int],
+        server_args: ServerArgs,
+        model_port_args: ModelPortArgs,
+        model_overide_args,
     ):
-
+        server_args, model_port_args = obtain(server_args), obtain(model_port_args)
+        self.tp_size = server_args.tp_size
 
-        if tp_size == 1:
+        if self.tp_size * server_args.dp_size == 1:
             # Init model
-
-
+            assert len(gpu_ids) == 1
+            self.model_server = ModelTpService().exposed_ModelTpServer(
+                0,
+                gpu_ids[0],
+                server_args,
+                model_port_args,
+                model_overide_args,
             )
 
             # Wrap functions
@@ -696,19 +756,31 @@ class ModelRpcClient:
 
             self.step = async_wrap(self.model_server.exposed_step)
         else:
-            with ThreadPoolExecutor(tp_size) as executor:
+            with ThreadPoolExecutor(self.tp_size) as executor:
                 # Launch model processes
-
-
-
+                if server_args.nnodes == 1:
+                    self.procs = list(executor.map(
+                        lambda args: start_rpyc_service_process(*args),
+                        [(ModelTpService, p) for p in model_port_args.model_tp_ports],
+                    ))
+                    addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
+                else:
+                    addrs = [(ip, port) for ip, port in zip(model_port_args.model_tp_ips, model_port_args.model_tp_ports)]
+
+                self.model_services = list(executor.map(
+                    lambda args: connect_rpyc_service(*args), addrs))
 
                 # Init model
                 def init_model(i):
-                    return self.
-                        i,
+                    return self.model_services[i].ModelTpServer(
+                        gpu_ids[i],
+                        i,
+                        server_args,
+                        model_port_args,
+                        model_overide_args,
                     )
 
-                self.model_servers = executor.map(init_model, range(tp_size))
+                self.model_servers = list(executor.map(init_model, range(self.tp_size)))
 
             # Wrap functions
             def async_wrap(func_name):
@@ -722,44 +794,3 @@ class ModelRpcClient:
             return _func
 
         self.step = async_wrap("step")
-
-
-def _init_service(port):
-    t = ThreadedServer(
-        ModelRpcService(),
-        port=port,
-        protocol_config={
-            "allow_public_attrs": True,
-            "allow_pickle": True,
-            "sync_request_timeout": 1800,
-        },
-    )
-    t.start()
-
-
-def start_model_process(port):
-    proc = multiprocessing.Process(target=_init_service, args=(port,))
-    proc.start()
-    time.sleep(1)
-
-    repeat_count = 0
-    while repeat_count < 20:
-        try:
-            con = rpyc.connect(
-                "localhost",
-                port,
-                config={
-                    "allow_public_attrs": True,
-                    "allow_pickle": True,
-                    "sync_request_timeout": 1800,
-                },
-            )
-            break
-        except ConnectionRefusedError:
-            time.sleep(1)
-            repeat_count += 1
-    if repeat_count == 20:
-        raise RuntimeError("init rpc env error!")
-
-    assert proc.is_alive()
-    return con.root, proc