sglang 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +7 -1
- sglang/bench_latency.py +9 -6
- sglang/bench_serving.py +46 -22
- sglang/global_config.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +60 -49
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +4 -2
- sglang/lang/ir.py +16 -7
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/constrained/jump_forward.py +13 -2
- sglang/srt/layers/activation.py +32 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +9 -2
- sglang/srt/layers/fused_moe/__init__.py +1 -0
- sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
- sglang/srt/layers/fused_moe/layer.py +587 -0
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +7 -2
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +40 -16
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +115 -97
- sglang/srt/managers/tokenizer_manager.py +194 -112
- sglang/srt/managers/tp_worker.py +290 -359
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +71 -25
- sglang/srt/model_executor/forward_batch_info.py +293 -156
- sglang/srt/model_executor/model_runner.py +77 -57
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/deepseek.py +2 -2
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +11 -6
- sglang/srt/models/grok.py +50 -396
- sglang/srt/models/internlm2.py +2 -7
- sglang/srt/models/llama2.py +4 -4
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/mixtral.py +56 -254
- sglang/srt/models/mixtral_quant.py +1 -4
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_moe.py +2 -13
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +187 -48
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -8
- sglang/srt/server.py +91 -29
- sglang/srt/server_args.py +32 -19
- sglang/srt/utils.py +32 -15
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +81 -73
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +36 -7
- sglang/test/test_utils.py +24 -2
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/METADATA +33 -16
- sglang-0.2.13.dist-info/RECORD +112 -0
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang/srt/model_loader/model_loader.py +0 -292
- sglang/srt/model_loader/utils.py +0 -275
- sglang-0.2.11.dist-info/RECORD +0 -102
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py
CHANGED
@@ -17,25 +17,30 @@ limitations under the License.
 
 import logging
 import multiprocessing
+import os
 import pickle
 import time
 import warnings
-from typing import List, Optional
+from typing import Any, List, Optional, Union
 
 import torch
+import torch.distributed
 import torch.distributed as dist
 
 from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
+    BatchEmbeddingOut,
     BatchTokenIDOut,
     FlushCacheReq,
+    TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.policy_scheduler import PolicyScheduler
+from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     BaseFinishReason,
@@ -49,7 +54,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
-    get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
     suppress_other_loggers,
@@ -59,6 +63,9 @@ from sglang.utils import get_exception_traceback
 logger = logging.getLogger(__name__)
 
 
+crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
+
+
 class ModelTpServer:
     def __init__(
         self,
@@ -98,26 +105,24 @@ class ModelTpServer:
             nccl_port=nccl_port,
             server_args=server_args,
         )
-
-        if is_multimodal_model(server_args.model_path):
-            self.processor = get_processor(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
-            self.tokenizer = self.processor.tokenizer
+        if server_args.skip_tokenizer_init:
+            self.tokenizer = self.processor = None
         else:
-            self.tokenizer = get_tokenizer(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
+            if is_multimodal_model(server_args.model_path):
+                self.processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
+                self.tokenizer = self.processor.tokenizer
+            else:
+                self.tokenizer = get_tokenizer(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
-        self.max_prefill_tokens = (
-            16384
-            if server_args.max_prefill_tokens is None
-            else server_args.max_prefill_tokens
-        )
+        self.max_prefill_tokens = server_args.max_prefill_tokens
         self.max_running_requests = min(
             (
                 self.max_total_num_tokens // 2
@@ -126,9 +131,6 @@ class ModelTpServer:
             ),
             self.model_runner.req_to_token_pool.size - 1,
         )
-        self.int_token_logit_bias = torch.tensor(
-            get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
-        )
         self.max_req_input_len = min(
             self.model_config.context_len - 1,
             self.max_total_num_tokens - 1,
@@ -160,13 +162,7 @@ class ModelTpServer:
             disable=server_args.disable_radix_cache,
         )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
-        self.scheduler = PolicyScheduler(
-            self.schedule_policy,
-            self.max_running_requests,
-            self.max_prefill_tokens,
-            self.max_total_num_tokens,
-            self.tree_cache,
-        )
+        self.scheduler = PolicyScheduler(self.schedule_policy, self.tree_cache)
         self.req_to_token_pool = self.model_runner.req_to_token_pool
         self.token_to_kv_pool = self.model_runner.token_to_kv_pool
 
@@ -180,13 +176,15 @@ class ModelTpServer:
         self.last_stats_tic = time.time()
 
         # Init the FSM cache for constrained generation
-        self.regex_fsm_cache = FSMCache(
-            server_args.tokenizer_path,
-            {
-                "tokenizer_mode": server_args.tokenizer_mode,
-                "trust_remote_code": server_args.trust_remote_code,
-            },
-        )
+        if not server_args.skip_tokenizer_init:
+            self.regex_fsm_cache = FSMCache(
+                server_args.tokenizer_path,
+                {
+                    "tokenizer_mode": server_args.tokenizer_mode,
+                    "trust_remote_code": server_args.trust_remote_code,
+                },
+                skip_tokenizer_init=server_args.skip_tokenizer_init,
+            )
         self.jump_forward_cache = JumpForwardCache()
 
         # Init new token estimation
@@ -201,11 +199,13 @@ class ModelTpServer:
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
 
-    def exposed_step(self, recv_reqs):
+    def exposed_step(self, recv_reqs: List):
         try:
             # Recv requests
             for recv_req in recv_reqs:
-                if isinstance(recv_req, TokenizedGenerateReqInput):
+                if isinstance(
+                    recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                ):
                     self.handle_generate_request(recv_req)
                 elif isinstance(recv_req, FlushCacheReq):
                     self.flush_cache()
@@ -232,8 +232,6 @@ class ModelTpServer:
             if new_batch is not None:
                 # Run a new prefill batch
                 self.forward_prefill_batch(new_batch)
-                self.cache_filled_batch(new_batch)
-                self.filter_out_inflight(new_batch)
 
                 if not new_batch.is_empty():
                     if self.running_batch is None:
@@ -250,7 +248,7 @@ class ModelTpServer:
 
                 # Print stats
                 if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
-                    self.
+                    self.print_decode_stats()
 
                 if self.running_batch.is_empty():
                     self.running_batch = None
@@ -262,7 +260,7 @@ class ModelTpServer:
             self.check_memory()
             self.new_token_ratio = global_config.init_new_token_ratio
 
-    def
+    def print_decode_stats(self):
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
@@ -288,6 +286,7 @@ class ModelTpServer:
                 f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
                 "KV cache pool leak detected!"
             )
+            exit(1) if crash_on_warning else None
 
         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
             warnings.warn(
@@ -296,44 +295,46 @@ class ModelTpServer:
                 f"total slots={self.req_to_token_pool.size}\n"
                 "Memory pool leak detected!"
             )
+            exit(1) if crash_on_warning else None
 
     def handle_generate_request(
         self,
-        recv_req: TokenizedGenerateReqInput,
+        recv_req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
     ):
         req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
-        req.pixel_values = recv_req.pixel_values
-        if req.pixel_values is not None:
-            req.pad_value = [
-                (recv_req.image_hash) % self.model_config.vocab_size,
-                (recv_req.image_hash >> 16) % self.model_config.vocab_size,
-                (recv_req.image_hash >> 32) % self.model_config.vocab_size,
-                (recv_req.image_hash >> 64) % self.model_config.vocab_size,
-            ]
-            req.image_size = recv_req.image_size
-            (
-                req.origin_input_ids,
-                req.image_offset,
-            ) = self.model_runner.model.pad_input_ids(
-                req.origin_input_ids_unpadded,
-                req.pad_value,
-                req.pixel_values.shape,
-                req.image_size,
-            )
-        req.sampling_params = recv_req.sampling_params
-        req.return_logprob = recv_req.return_logprob
-        req.logprob_start_len = recv_req.logprob_start_len
-        req.top_logprobs_num = recv_req.top_logprobs_num
-        req.stream = recv_req.stream
         req.tokenizer = self.tokenizer
-
-        # Init regex fsm
-        if req.sampling_params.regex is not None:
-            req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
-            if not self.disable_regex_jump_forward:
-                req.jump_forward_map = self.jump_forward_cache.query(
-                    req.sampling_params.regex
+        req.sampling_params = recv_req.sampling_params
+        if self.model_runner.is_generation:
+            req.pixel_values = recv_req.pixel_values
+            if req.pixel_values is not None:
+                req.pad_value = [
+                    (recv_req.image_hash) % self.model_config.vocab_size,
+                    (recv_req.image_hash >> 16) % self.model_config.vocab_size,
+                    (recv_req.image_hash >> 32) % self.model_config.vocab_size,
+                    (recv_req.image_hash >> 64) % self.model_config.vocab_size,
+                ]
+                req.image_size = recv_req.image_size
+                (
+                    req.origin_input_ids,
+                    req.image_offset,
+                ) = self.model_runner.model.pad_input_ids(
+                    req.origin_input_ids_unpadded,
+                    req.pad_value,
+                    req.pixel_values.shape,
+                    req.image_size,
                 )
+            req.return_logprob = recv_req.return_logprob
+            req.logprob_start_len = recv_req.logprob_start_len
+            req.top_logprobs_num = recv_req.top_logprobs_num
+            req.stream = recv_req.stream
+
+            # Init regex fsm
+            if req.sampling_params.regex is not None:
+                req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
+                if not self.disable_regex_jump_forward:
+                    req.jump_forward_map = self.jump_forward_cache.query(
+                        req.sampling_params.regex
+                    )
 
         # Truncate prompts that are too long
         if len(req.origin_input_ids) >= self.max_req_input_len:
@@ -342,186 +343,87 @@ class ModelTpServer:
                 "the max context length. Truncated!!!"
             )
             req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
-
-
-
-
-
-
-
-
+
+        if self.model_runner.is_generation:
+            req.sampling_params.max_new_tokens = min(
+                (
+                    req.sampling_params.max_new_tokens
+                    if req.sampling_params.max_new_tokens is not None
+                    else 1 << 30
+                ),
+                self.max_req_input_len - 1 - len(req.origin_input_ids),
+            )
+
         self.waiting_queue.append(req)
 
     def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
-        # TODO(lsyin): organize this function
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
         if running_bs >= self.max_running_requests:
-            return
-
-        # Compute matched prefix length
-        for req in self.waiting_queue:
-            req.input_ids = req.origin_input_ids + req.output_ids
-            try_match_ids = req.input_ids
-            if req.return_logprob:
-                try_match_ids = req.input_ids[: req.logprob_start_len]
-            # NOTE: the prefix_indices must always be aligned with last_node
-            prefix_indices, last_node = self.tree_cache.match_prefix(
-                rid=req.rid, key=try_match_ids
-            )
-            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
-            req.prefix_indices = prefix_indices
-            req.last_node = last_node
+            return None
 
         # Get priority queue
-
-
-        # Add requests if there is available space
-        can_run_list = []
-        new_batch_total_tokens = 0
-        new_batch_input_tokens = 0
+        prefix_computed = self.scheduler.calc_priority(self.waiting_queue)
 
-
-            self.
+        adder = PrefillAdder(
+            self.tree_cache,
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
+            self.max_prefill_tokens,
+            self.chunked_prefill_size,
         )
-        if self.running_batch:
-            available_size -= sum(
-                [
-                    (r.sampling_params.max_new_tokens - len(r.output_ids))
-                    * self.new_token_ratio
-                    for r in self.running_batch.reqs
-                ]
-            )
 
-
-
-
-
-
-
-
-                len(r.input_ids) - len(r.prefix_indices) > self.chunked_prefill_size
+        if self.running_batch is not None:
+            adder.remove_running_tokens(self.running_batch, self.new_token_ratio)
+
+        has_inflight = self.current_inflight_req is not None
+        if self.current_inflight_req is not None:
+            self.current_inflight_req.init_next_round_input(
+                None if prefix_computed else self.tree_cache
            )
-
-
+            self.current_inflight_req = adder.add_inflight_req(
+                self.current_inflight_req
            )
-            r.input_ids = r.input_ids[: len(r.prefix_indices) + r.extend_input_len]
-            can_run_list.append(r)
-
-            if not truncated:
-                # Finish inflight
-                self.current_inflight_req = None
-                new_batch_total_tokens += (
-                    r.extend_input_len + r.sampling_params.max_new_tokens
-                )
-                new_batch_input_tokens += r.extend_input_len
-            else:
-                new_batch_total_tokens += r.extend_input_len
-                new_batch_input_tokens += r.extend_input_len
 
         for req in self.waiting_queue:
-
-
-                if req.extend_input_len < 2:
-                    delta = 2 - req.extend_input_len
-                    req.extend_input_len += delta
-                    req.prefix_indices = req.prefix_indices[:-delta]
-                    if req.image_offset is not None:
-                        req.image_offset += delta
-                if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
-                    # Need at least one token to compute logits
-                    req.extend_input_len = 1
-                    req.prefix_indices = req.prefix_indices[:-1]
-                    if req.image_offset is not None:
-                        req.image_offset += 1
-
+            req.init_next_round_input(None if prefix_computed else self.tree_cache)
+            res = adder.add_one_req(req)
             if (
-
-
-                +
-                < available_size
-                and (
-                    req.extend_input_len + new_batch_input_tokens
-                    <= self.max_prefill_tokens
-                    or len(can_run_list) == 0
-                )
+                not res
+                or adder.no_remaining_tokens()
+                or running_bs + len(adder.can_run_list) >= self.max_running_requests
            ):
-                delta = self.tree_cache.inc_lock_ref(req.last_node)
-                available_size += delta
-
-                if not (
-                    req.extend_input_len
-                    + req.sampling_params.max_new_tokens
-                    + new_batch_total_tokens
-                    < available_size
-                ):
-                    # Undo locking
-                    delta = self.tree_cache.dec_lock_ref(req.last_node)
-                    available_size += delta
-                    break
-                else:
-                    # Add this request to the running batch
-                    if (
-                        self.chunked_prefill_size is None
-                        or (
-                            new_batch_input_tokens + req.extend_input_len
-                            <= self.chunked_prefill_size
-                        )
-                        or (
-                            req.return_logprob and req.normalized_prompt_logprob is None
-                        )
-                    ):
-                        can_run_list.append(req)
-                        new_batch_total_tokens += (
-                            req.extend_input_len + req.sampling_params.max_new_tokens
-                        )
-                        new_batch_input_tokens += req.extend_input_len
-                    else:
-                        trunc_len = self.chunked_prefill_size - new_batch_input_tokens
-
-                        if trunc_len <= 0:
-                            # Undo locking
-                            delta = self.tree_cache.dec_lock_ref(req.last_node)
-                            available_size += delta
-                            break
-
-                        req.extend_input_len = trunc_len
-                        req.input_ids = req.input_ids[
-                            : len(req.prefix_indices) + req.extend_input_len
-                        ]
-                        can_run_list.append(req)
-                        self.current_inflight_req = req
-                        new_batch_input_tokens += req.extend_input_len
-                        new_batch_total_tokens += req.extend_input_len
-                        break
-            else:
                 break
 
-
-
+        can_run_list = adder.can_run_list
+
+        if adder.new_inflight_req is not None:
+            assert self.current_inflight_req is None
+            self.current_inflight_req = adder.new_inflight_req
 
         if len(can_run_list) == 0:
             return None
 
         # Print stats
         if self.tp_rank == 0:
-
-
-
-
-
-
-
-
+            if isinstance(self.tree_cache, RadixCache):
+                self.tree_cache_metrics["total"] += (
+                    adder.log_input_tokens + adder.log_hit_tokens
+                ) / 10**9
+                self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
+                tree_cache_hit_rate = (
+                    self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
+                )
+            else:
+                tree_cache_hit_rate = 0.0
             logger.info(
                 f"[gpu={self.gpu_id}] Prefill batch. "
                 f"#new-seq: {len(can_run_list)}, "
-                f"#new-token: {
-                f"#cached-token: {
+                f"#new-token: {adder.log_input_tokens}, "
+                f"#cached-token: {adder.log_hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                 f"#running-req: {running_bs}, "
-                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) +
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
             )
 
         # Return the new batch
@@ -536,45 +438,90 @@ class ModelTpServer:
 
     def forward_prefill_batch(self, batch: ScheduleBatch):
         # Build batch tensors
-        batch.prepare_for_extend(
-
-
+        batch.prepare_for_extend(self.model_config.vocab_size)
+
+        if self.model_runner.is_generation:
+            # Forward and sample the next tokens
+            if batch.extend_num_tokens != 0:
+                output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+                next_token_ids = batch.sample(output.next_token_logits)
+
+                # Move logprobs to cpu
+                if output.next_token_logprobs is not None:
+                    output.next_token_logprobs = output.next_token_logprobs[
+                        torch.arange(len(next_token_ids), device=next_token_ids.device),
+                        next_token_ids,
+                    ].tolist()
+                    output.input_token_logprobs = output.input_token_logprobs.tolist()
+                    output.normalized_prompt_logprobs = (
+                        output.normalized_prompt_logprobs.tolist()
+                    )
 
-
-
-
-
-
-
-
-
-
-                    next_token_ids
-
-
-
-
-
+                next_token_ids = next_token_ids.tolist()
+            else:
+                if self.tokenizer is None:
+                    next_token_ids = []
+                    for req in batch.reqs:
+                        next_token_ids.append(
+                            next(iter(req.sampling_params.stop_token_ids))
+                        )
+                else:
+                    next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
+
+            # Check finish conditions
+            pt = 0
+            for i, req in enumerate(batch.reqs):
+                if req is not self.current_inflight_req:
+                    # Inflight reqs' prefill is not finished
+                    req.completion_tokens_wo_jump_forward += 1
+                    req.output_ids.append(next_token_ids[i])
+                    req.check_finished()
+
+                if req.finished():
+                    self.tree_cache.cache_finished_req(req)
+                else:
+                    self.tree_cache.cache_unfinished_req(req)
 
-
-
-
+                if req is self.current_inflight_req:
+                    # Inflight request would get a new req idx
+                    self.req_to_token_pool.free(req.req_pool_idx)
 
-
-
-
-
-
-
-
+                if req.return_logprob:
+                    self.add_logprob_return_values(i, req, pt, next_token_ids, output)
+                    pt += req.extend_input_len
+        else:
+            assert batch.extend_num_tokens != 0
+            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            embeddings = output.embeddings.tolist()
+
+            # Check finish conditions
+            for i, req in enumerate(batch.reqs):
+                req.embedding = embeddings[i]
+                if req is not self.current_inflight_req:
+                    # Inflight reqs' prefill is not finished
+                    # dummy output token for embedding models
+                    req.output_ids.append(0)
+                    req.check_finished()
+
+                if req.finished():
+                    self.tree_cache.cache_finished_req(req)
+                else:
+                    self.tree_cache.cache_unfinished_req(req)
 
-
-
-
+                if req is self.current_inflight_req:
+                    # Inflight request would get a new req idx
+                    self.req_to_token_pool.free(req.req_pool_idx)
 
         self.handle_finished_requests(batch)
 
-    def add_logprob_return_values(
+    def add_logprob_return_values(
+        self,
+        i,
+        req: Req,
+        pt: int,
+        next_token_ids: List[int],
+        output: LogitProcessorOutput,
+    ):
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
 
@@ -583,12 +530,12 @@ class ModelTpServer:
             req.input_token_logprobs = list(
                 zip(
                     output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
-                    req.
+                    req.fill_ids[-req.extend_input_len + 1 :],
                 )
             )
             if req.logprob_start_len == 0:
                 req.input_token_logprobs = [
-                    (None, req.
+                    (None, req.fill_ids[0])
                 ] + req.input_token_logprobs
 
         if req.last_update_decode_tokens != 0:
@@ -602,7 +549,7 @@ class ModelTpServer:
                         + req.extend_input_len
                         - 1
                     ],
-                    req.
+                    req.fill_ids[-req.last_update_decode_tokens + 1 :],
                 )
             )
         )
@@ -623,22 +570,6 @@ class ModelTpServer:
             )
             req.output_top_logprobs.append(output.output_top_logprobs[i])
 
-    def cache_filled_batch(self, batch: ScheduleBatch):
-        for i, req in enumerate(batch.reqs):
-            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                rid=req.rid,
-                token_ids=tuple(req.input_ids),
-                last_uncached_pos=len(req.prefix_indices),
-                req_pool_idx=req.req_pool_idx,
-                del_in_memory_pool=False,
-                old_last_node=req.last_node,
-            )
-            req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
-
-            if req is self.current_inflight_req:
-                # inflight request would get a new req idx
-                self.req_to_token_pool.free(req.req_pool_idx)
-
     def forward_decode_batch(self, batch: ScheduleBatch):
         # Check if decode out of memory
         if not batch.check_decode_mem():
@@ -689,6 +620,9 @@ class ModelTpServer:
             req.output_ids.append(next_token_id)
             req.check_finished()
 
+            if req.finished():
+                self.tree_cache.cache_finished_req(req)
+
             if req.return_logprob:
                 req.output_token_logprobs.append(
                     (next_token_logprobs[i], next_token_id)
@@ -700,20 +634,21 @@ class ModelTpServer:
 
     def handle_finished_requests(self, batch: ScheduleBatch):
         output_rids = []
-        output_vids = []
-        decoded_texts = []
-        output_read_ids = []
-        output_read_offsets = []
-        output_skip_special_tokens = []
-        output_spaces_between_special_tokens = []
         output_meta_info = []
         output_finished_reason: List[BaseFinishReason] = []
-        finished_indices = []
+        if self.model_runner.is_generation:
+            output_vids = []
+            decoded_texts = []
+            output_read_ids = []
+            output_read_offsets = []
+            output_skip_special_tokens = []
+            output_spaces_between_special_tokens = []
+        else:  # for embedding model
+            output_embeddings = []
         unfinished_indices = []
+
         for i, req in enumerate(batch.reqs):
-            if req.finished():
-                finished_indices.append(i)
-            else:
+            if not req.finished() and req is not self.current_inflight_req:
                 unfinished_indices.append(i)
 
             if req.finished() or (
@@ -726,85 +661,75 @@ class ModelTpServer:
                 )
             ):
                 output_rids.append(req.rid)
-                output_vids.append(req.vid)
-                decoded_texts.append(req.decoded_text)
-                read_ids, read_offset = req.init_incremental_detokenize()
-                output_read_ids.append(read_ids)
-                output_read_offsets.append(read_offset)
-                output_skip_special_tokens.append(
-                    req.sampling_params.skip_special_tokens
-                )
-                output_spaces_between_special_tokens.append(
-                    req.sampling_params.spaces_between_special_tokens
-                )
-
-                meta_info = {
-                    "prompt_tokens": len(req.origin_input_ids),
-                    "completion_tokens": len(req.output_ids),
-                    "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                    "finish_reason": str(req.finished_reason),
-                }
-                if req.return_logprob:
-                    (
-                        meta_info["input_token_logprobs"],
-                        meta_info["output_token_logprobs"],
-                        meta_info["input_top_logprobs"],
-                        meta_info["output_top_logprobs"],
-                        meta_info["normalized_prompt_logprob"],
-                    ) = (
-                        req.input_token_logprobs,
-                        req.output_token_logprobs,
-                        req.input_top_logprobs,
-                        req.output_top_logprobs,
-                        req.normalized_prompt_logprob,
-                    )
-                output_meta_info.append(meta_info)
                 output_finished_reason.append(req.finished_reason)
+                if self.model_runner.is_generation:
+                    output_vids.append(req.vid)
+                    decoded_texts.append(req.decoded_text)
+                    read_ids, read_offset = req.init_incremental_detokenize()
+                    output_read_ids.append(read_ids)
+                    output_read_offsets.append(read_offset)
+                    output_skip_special_tokens.append(
+                        req.sampling_params.skip_special_tokens
+                    )
+                    output_spaces_between_special_tokens.append(
+                        req.sampling_params.spaces_between_special_tokens
+                    )
+
+                    meta_info = {
+                        "prompt_tokens": len(req.origin_input_ids),
+                        "completion_tokens": len(req.output_ids),
+                        "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
+                        "finish_reason": str(req.finished_reason),
+                    }
+                    if req.return_logprob:
+                        (
+                            meta_info["input_token_logprobs"],
+                            meta_info["output_token_logprobs"],
+                            meta_info["input_top_logprobs"],
+                            meta_info["output_top_logprobs"],
+                            meta_info["normalized_prompt_logprob"],
+                        ) = (
+                            req.input_token_logprobs,
+                            req.output_token_logprobs,
+                            req.input_top_logprobs,
+                            req.output_top_logprobs,
+                            req.normalized_prompt_logprob,
+                        )
+                    output_meta_info.append(meta_info)
+                else:  # for embedding model
+                    output_embeddings.append(req.embedding)
+                    meta_info = {
+                        "prompt_tokens": len(req.origin_input_ids),
+                    }
+                    output_meta_info.append(meta_info)
 
         # Send to detokenizer
         if output_rids:
-            self.
-
-
-
-
-
-
-
-
-
-
+            if self.model_runner.is_generation:
+                self.out_pyobjs.append(
+                    BatchTokenIDOut(
+                        output_rids,
+                        output_vids,
+                        decoded_texts,
+                        output_read_ids,
+                        output_read_offsets,
+                        output_skip_special_tokens,
+                        output_spaces_between_special_tokens,
+                        output_meta_info,
+                        output_finished_reason,
+                    )
                 )
-
-
-
-
-
-
-
-
-                rid=req.rid,
-                token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
-                last_uncached_pos=len(req.prefix_indices),
-                req_pool_idx=req.req_pool_idx,
+            else:  # for embedding model
+                self.out_pyobjs.append(
+                    BatchEmbeddingOut(
+                        output_rids,
+                        output_embeddings,
+                        output_meta_info,
+                        output_finished_reason,
+                    )
                 )
-
-
-        # Update batch tensors
-        if unfinished_indices:
-            batch.filter_batch(unfinished_indices)
-        else:
-            batch.reqs = []
-
-    def filter_out_inflight(self, batch: ScheduleBatch):
-        # TODO(lsyin): reduce the overhead, make a special version for this
-        if self.current_inflight_req is None:
-            return
-
-        to_remove = batch.reqs.index(self.current_inflight_req)
-        unfinished_indices = [i for i in range(len(batch.reqs)) if i != to_remove]
-
+        # Remove finished reqs: update batch tensors
         batch.filter_batch(unfinished_indices)
 
     def flush_cache(self):
@@ -871,7 +796,11 @@ def run_tp_server(
 
 
 def launch_tp_servers(
-    gpu_ids
+    gpu_ids: List[int],
+    tp_rank_range: List[int],
+    server_args: ServerArgs,
+    nccl_port: int,
+    model_overide_args: dict,
 ):
     """Launch multiple tensor parallel servers."""
     procs = []
@@ -886,7 +815,9 @@ def launch_tp_servers(
     return procs
 
 
-def broadcast_recv_input(
+def broadcast_recv_input(
+    data: Any, rank: int, dist_group: torch.distributed.ProcessGroup
+):
     """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
 
     if rank == 0: