sglang 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
- sglang/__init__.py +8 -0
- sglang/api.py +10 -2
- sglang/bench_latency.py +151 -40
- sglang/bench_serving.py +46 -22
- sglang/check_env.py +24 -2
- sglang/global_config.py +0 -1
- sglang/lang/backend/base_backend.py +3 -1
- sglang/lang/backend/openai.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +46 -29
- sglang/lang/choices.py +164 -0
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +6 -13
- sglang/lang/ir.py +14 -5
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/layers/activation.py +33 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +6 -1
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +6 -1
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +4 -7
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +174 -380
- sglang/srt/managers/tokenizer_manager.py +197 -112
- sglang/srt/managers/tp_worker.py +299 -364
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +10 -15
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +27 -12
- sglang/srt/model_executor/forward_batch_info.py +319 -0
- sglang/srt/model_executor/model_runner.py +30 -47
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +1 -1
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -2
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/internlm2.py +3 -8
- sglang/srt/models/llama2.py +5 -5
- sglang/srt/models/llama_classification.py +1 -1
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/llava.py +1 -2
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -1
- sglang/srt/models/mixtral.py +1 -1
- sglang/srt/models/mixtral_quant.py +1 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -12
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +189 -39
- sglang/srt/openai_api/protocol.py +43 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -4
- sglang/srt/server.py +93 -21
- sglang/srt/server_args.py +30 -19
- sglang/srt/utils.py +31 -13
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +63 -63
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +4 -2
- sglang/test/test_utils.py +21 -3
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/METADATA +50 -31
- sglang-0.2.12.dist-info/RECORD +112 -0
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang-0.2.10.dist-info/RECORD +0 -100
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py
CHANGED
@@ -17,35 +17,40 @@ limitations under the License.
 
 import logging
 import multiprocessing
+import os
 import pickle
 import time
 import warnings
-from typing import List, Optional
+from typing import Any, List, Optional, Union
 
 import torch
+import torch.distributed
 import torch.distributed as dist
 
 from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
+    BatchEmbeddingOut,
     BatchTokenIDOut,
     FlushCacheReq,
+    TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.policy_scheduler import PolicyScheduler
+from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     BaseFinishReason,
-    Batch,
-    ForwardMode,
     Req,
+    ScheduleBatch,
 )
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.model_config import ModelConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -59,6 +64,9 @@ from sglang.utils import get_exception_traceback
 logger = logging.getLogger(__name__)
 
 
+crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
+
+
 class ModelTpServer:
     def __init__(
         self,
@@ -98,26 +106,24 @@ class ModelTpServer:
             nccl_port=nccl_port,
             server_args=server_args,
         )
-
-        if is_multimodal_model(server_args.model_path):
-            self.processor = get_processor(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
-            self.tokenizer = self.processor.tokenizer
+        if server_args.skip_tokenizer_init:
+            self.tokenizer = self.processor = None
         else:
-            self.tokenizer = get_tokenizer(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
+            if is_multimodal_model(server_args.model_path):
+                self.processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
+                self.tokenizer = self.processor.tokenizer
+            else:
+                self.tokenizer = get_tokenizer(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
-        self.max_prefill_tokens = (
-            16384
-            if server_args.max_prefill_tokens is None
-            else server_args.max_prefill_tokens
-        )
+        self.max_prefill_tokens = server_args.max_prefill_tokens
         self.max_running_requests = min(
             (
                 self.max_total_num_tokens // 2
@@ -160,19 +166,13 @@ class ModelTpServer:
                 disable=server_args.disable_radix_cache,
             )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
-        self.scheduler = PolicyScheduler(
-            self.schedule_policy,
-            self.max_running_requests,
-            self.max_prefill_tokens,
-            self.max_total_num_tokens,
-            self.tree_cache,
-        )
+        self.scheduler = PolicyScheduler(self.schedule_policy, self.tree_cache)
         self.req_to_token_pool = self.model_runner.req_to_token_pool
         self.token_to_kv_pool = self.model_runner.token_to_kv_pool
 
         # Init running status
         self.waiting_queue: List[Req] = []
-        self.running_batch: Batch = None
+        self.running_batch: ScheduleBatch = None
         self.out_pyobjs = []
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
@@ -180,13 +180,15 @@ class ModelTpServer:
         self.last_stats_tic = time.time()
 
         # Init the FSM cache for constrained generation
-        self.regex_fsm_cache = FSMCache(
-            server_args.tokenizer_path,
-            {
-                "tokenizer_mode": server_args.tokenizer_mode,
-                "trust_remote_code": server_args.trust_remote_code,
-            },
-        )
+        if not server_args.skip_tokenizer_init:
+            self.regex_fsm_cache = FSMCache(
+                server_args.tokenizer_path,
+                {
+                    "tokenizer_mode": server_args.tokenizer_mode,
+                    "trust_remote_code": server_args.trust_remote_code,
+                },
+                skip_tokenizer_init=server_args.skip_tokenizer_init,
+            )
         self.jump_forward_cache = JumpForwardCache()
 
         # Init new token estimation
@@ -200,13 +202,14 @@ class ModelTpServer:
         )
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
-        self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
 
-    def exposed_step(self, recv_reqs):
+    def exposed_step(self, recv_reqs: List):
         try:
             # Recv requests
             for recv_req in recv_reqs:
-                if isinstance(recv_req, TokenizedGenerateReqInput):
+                if isinstance(
+                    recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                ):
                     self.handle_generate_request(recv_req)
                 elif isinstance(recv_req, FlushCacheReq):
                     self.flush_cache()
@@ -233,8 +236,6 @@ class ModelTpServer:
         if new_batch is not None:
             # Run a new prefill batch
             self.forward_prefill_batch(new_batch)
-            self.cache_filled_batch(new_batch)
-            self.filter_out_inflight(new_batch)
 
             if not new_batch.is_empty():
                 if self.running_batch is None:
@@ -251,7 +252,7 @@ class ModelTpServer:
 
                     # Print stats
                     if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
-                        self.
+                        self.print_decode_stats()
 
                     if self.running_batch.is_empty():
                         self.running_batch = None
@@ -263,7 +264,7 @@ class ModelTpServer:
                 self.check_memory()
                 self.new_token_ratio = global_config.init_new_token_ratio
 
-    def
+    def print_decode_stats(self):
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
@@ -289,52 +290,55 @@ class ModelTpServer:
                 f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
                 "KV cache pool leak detected!"
             )
+            exit(1) if crash_on_warning else None
 
-        if self.req_to_token_pool.
+        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
             warnings.warn(
                 "Warning: "
-                f"available req slots={self.req_to_token_pool.
+                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
                 f"total slots={self.req_to_token_pool.size}\n"
                 "Memory pool leak detected!"
            )
+            exit(1) if crash_on_warning else None
 
     def handle_generate_request(
         self,
-        recv_req: TokenizedGenerateReqInput,
+        recv_req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
     ):
         req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
-        req.pixel_values = recv_req.pixel_values
-        if req.pixel_values is not None:
-            req.pad_value = [
-                (recv_req.image_hash) % self.model_config.vocab_size,
-                (recv_req.image_hash >> 16) % self.model_config.vocab_size,
-                (recv_req.image_hash >> 32) % self.model_config.vocab_size,
-                (recv_req.image_hash >> 64) % self.model_config.vocab_size,
-            ]
-            req.image_size = recv_req.image_size
-            (
-                req.origin_input_ids,
-                req.image_offset,
-            ) = self.model_runner.model.pad_input_ids(
-                req.origin_input_ids_unpadded,
-                req.pad_value,
-                req.pixel_values.shape,
-                req.image_size,
-            )
-        req.sampling_params = recv_req.sampling_params
-        req.return_logprob = recv_req.return_logprob
-        req.logprob_start_len = recv_req.logprob_start_len
-        req.top_logprobs_num = recv_req.top_logprobs_num
-        req.stream = recv_req.stream
         req.tokenizer = self.tokenizer
-
-        # Init regex fsm
-        if req.sampling_params.regex is not None:
-            req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
-            if not self.disable_regex_jump_forward:
-                req.jump_forward_map = self.jump_forward_cache.query(
-                    req.sampling_params.regex
+        req.sampling_params = recv_req.sampling_params
+        if self.model_runner.is_generation:
+            req.pixel_values = recv_req.pixel_values
+            if req.pixel_values is not None:
+                req.pad_value = [
+                    (recv_req.image_hash) % self.model_config.vocab_size,
+                    (recv_req.image_hash >> 16) % self.model_config.vocab_size,
+                    (recv_req.image_hash >> 32) % self.model_config.vocab_size,
+                    (recv_req.image_hash >> 64) % self.model_config.vocab_size,
+                ]
+                req.image_size = recv_req.image_size
+                (
+                    req.origin_input_ids,
+                    req.image_offset,
+                ) = self.model_runner.model.pad_input_ids(
+                    req.origin_input_ids_unpadded,
+                    req.pad_value,
+                    req.pixel_values.shape,
+                    req.image_size,
                 )
+            req.return_logprob = recv_req.return_logprob
+            req.logprob_start_len = recv_req.logprob_start_len
+            req.top_logprobs_num = recv_req.top_logprobs_num
+            req.stream = recv_req.stream
+
+            # Init regex fsm
+            if req.sampling_params.regex is not None:
+                req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
+                if not self.disable_regex_jump_forward:
+                    req.jump_forward_map = self.jump_forward_cache.query(
+                        req.sampling_params.regex
+                    )
 
         # Truncate prompts that are too long
         if len(req.origin_input_ids) >= self.max_req_input_len:
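The two `exit(1) if crash_on_warning else None` lines added in this hunk turn the pool-leak warnings into hard failures only when the `SGLANG_IS_IN_CI` environment variable is set. A minimal standalone sketch of that pattern (the helper name and arguments below are illustrative, not the sglang API):

import os
import warnings

crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"


def check_pool(free_slots: int, total_slots: int) -> None:
    # A leak is only a warning in normal runs, but aborts the process in CI
    # so the regression is caught by the test job.
    if free_slots != total_slots:
        warnings.warn(
            f"available req slots={free_slots}, total slots={total_slots}\n"
            "Memory pool leak detected!"
        )
        exit(1) if crash_on_warning else None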
@@ -343,189 +347,91 @@ class ModelTpServer:
                 "the max context length. Truncated!!!"
             )
             req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
-        req.sampling_params.max_new_tokens = min(
-            (
-                req.sampling_params.max_new_tokens
-                if req.sampling_params.max_new_tokens is not None
-                else 1 << 30
-            ),
-            self.max_req_input_len - 1 - len(req.origin_input_ids),
-        )
+
+        if self.model_runner.is_generation:
+            req.sampling_params.max_new_tokens = min(
+                (
+                    req.sampling_params.max_new_tokens
+                    if req.sampling_params.max_new_tokens is not None
+                    else 1 << 30
+                ),
+                self.max_req_input_len - 1 - len(req.origin_input_ids),
+            )
+
         self.waiting_queue.append(req)
 
-    def get_new_prefill_batch(self) -> Optional[
-        # TODO(lsyin): organize this function
+    def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
         if running_bs >= self.max_running_requests:
-            return
-
-        # Compute matched prefix length
-        for req in self.waiting_queue:
-            req.input_ids = req.origin_input_ids + req.output_ids
-            prefix_indices, last_node = self.tree_cache.match_prefix(
-                rid=req.rid,
-                key=req.input_ids,
-            )
-            if req.return_logprob:
-                prefix_indices = prefix_indices[: req.logprob_start_len]
-            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
-            req.prefix_indices = prefix_indices
-            req.last_node = last_node
+            return None
 
         # Get priority queue
-
+        prefix_computed = self.scheduler.calc_priority(self.waiting_queue)
 
-
-
-
-
-        available_size = (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+        adder = PrefillAdder(
+            self.tree_cache,
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
+            self.max_prefill_tokens,
+            self.chunked_prefill_size,
         )
-        if self.running_batch:
-            available_size -= sum(
-                [
-                    (r.sampling_params.max_new_tokens - len(r.output_ids))
-                    * self.new_token_ratio
-                    for r in self.running_batch.reqs
-                ]
-            )
 
-
-
-
-
-
-
-
-                len(r.input_ids) - len(r.prefix_indices) > self.chunked_prefill_size
+        if self.running_batch is not None:
+            adder.remove_running_tokens(self.running_batch, self.new_token_ratio)
+
+        has_inflight = self.current_inflight_req is not None
+        if self.current_inflight_req is not None:
+            self.current_inflight_req.init_next_round_input(
+                None if prefix_computed else self.tree_cache
             )
-
-
+            self.current_inflight_req = adder.add_inflight_req(
+                self.current_inflight_req
            )
-            r.input_ids = r.input_ids[: len(r.prefix_indices) + r.extend_input_len]
-            can_run_list.append(r)
-
-            if not truncated:
-                # Finish inflight
-                self.current_inflight_req = None
-                new_batch_total_tokens += (
-                    r.extend_input_len + r.sampling_params.max_new_tokens
-                )
-                new_batch_input_tokens += r.extend_input_len
-            else:
-                new_batch_total_tokens += r.extend_input_len
-                new_batch_input_tokens += r.extend_input_len
 
         for req in self.waiting_queue:
-
-
-            if req.extend_input_len < 2:
-                delta = 2 - req.extend_input_len
-                req.extend_input_len += delta
-                req.prefix_indices = req.prefix_indices[:-delta]
-                if req.image_offset is not None:
-                    req.image_offset += delta
-            if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
-                # Need at least one token to compute logits
-                req.extend_input_len = 1
-                req.prefix_indices = req.prefix_indices[:-1]
-                if req.image_offset is not None:
-                    req.image_offset += 1
-
+            req.init_next_round_input(None if prefix_computed else self.tree_cache)
+            res = adder.add_one_req(req)
             if (
-                req.extend_input_len
-                + req.sampling_params.max_new_tokens
-                + new_batch_total_tokens
-                < available_size
-                and (
-                    req.extend_input_len + new_batch_input_tokens
-                    <= self.max_prefill_tokens
-                    or len(can_run_list) == 0
-                )
+                not res
+                or adder.no_remaining_tokens()
+                or running_bs + len(adder.can_run_list) >= self.max_running_requests
             ):
-                delta = self.tree_cache.inc_lock_ref(req.last_node)
-                available_size += delta
-
-                if not (
-                    req.extend_input_len
-                    + req.sampling_params.max_new_tokens
-                    + new_batch_total_tokens
-                    < available_size
-                ):
-                    # Undo locking
-                    delta = self.tree_cache.dec_lock_ref(req.last_node)
-                    available_size += delta
-                    break
-                else:
-                    # Add this request to the running batch
-                    if (
-                        self.chunked_prefill_size is None
-                        or (
-                            new_batch_input_tokens + req.extend_input_len
-                            <= self.chunked_prefill_size
-                        )
-                        or (
-                            req.return_logprob and req.normalized_prompt_logprob is None
-                        )
-                    ):
-                        can_run_list.append(req)
-                        new_batch_total_tokens += (
-                            req.extend_input_len + req.sampling_params.max_new_tokens
-                        )
-                        new_batch_input_tokens += req.extend_input_len
-                    else:
-                        trunc_len = self.chunked_prefill_size - new_batch_input_tokens
-
-                        if trunc_len <= 0:
-                            # Undo locking
-                            delta = self.tree_cache.dec_lock_ref(req.last_node)
-                            available_size += delta
-                            break
-
-                        req.extend_input_len = trunc_len
-                        req.input_ids = req.input_ids[
-                            : len(req.prefix_indices) + req.extend_input_len
-                        ]
-                        can_run_list.append(req)
-                        self.current_inflight_req = req
-                        new_batch_input_tokens += req.extend_input_len
-                        new_batch_total_tokens += req.extend_input_len
-                        break
-            else:
                 break
 
-
-
+        can_run_list = adder.can_run_list
+
+        if adder.new_inflight_req is not None:
+            assert self.current_inflight_req is None
+            self.current_inflight_req = adder.new_inflight_req
 
         if len(can_run_list) == 0:
             return None
 
         # Print stats
         if self.tp_rank == 0:
-
-
-
-
-
-
-
-
+            if isinstance(self.tree_cache, RadixCache):
+                self.tree_cache_metrics["total"] += (
+                    adder.log_input_tokens + adder.log_hit_tokens
+                ) / 10**9
+                self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
+                tree_cache_hit_rate = (
+                    self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
+                )
+            else:
+                tree_cache_hit_rate = 0.0
             logger.info(
                 f"[gpu={self.gpu_id}] Prefill batch. "
                 f"#new-seq: {len(can_run_list)}, "
-                f"#new-token: {
-                f"#cached-token: {
+                f"#new-token: {adder.log_input_tokens}, "
+                f"#cached-token: {adder.log_hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                 f"#running-req: {running_bs}, "
-                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) +
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
             )
 
         # Return the new batch
-        new_batch =
+        new_batch = ScheduleBatch.init_new(
             can_run_list,
             self.req_to_token_pool,
             self.token_to_kv_pool,
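The rewritten get_new_prefill_batch delegates the token-budget bookkeeping that used to be inlined here to the new PrefillAdder helper. The sketch below is an illustrative, simplified model of that accounting, not the sglang implementation: each admitted request reserves its prefill tokens plus its decode budget against the remaining KV-cache space, and admission stops as soon as a request no longer fits.

from dataclasses import dataclass, field
from typing import List


@dataclass
class PendingReq:
    input_len: int          # tokens that still need prefill
    max_new_tokens: int     # decode budget reserved for this request


@dataclass
class SimplePrefillAdder:
    rem_total_tokens: int   # free KV-cache tokens (plus evictable cache)
    rem_input_tokens: int   # per-batch prefill token budget
    can_run_list: List[PendingReq] = field(default_factory=list)

    def add_one_req(self, req: PendingReq) -> bool:
        # Reserve prefill + decode tokens; refuse the request if either budget is exceeded.
        projected = req.input_len + req.max_new_tokens
        if projected > self.rem_total_tokens or req.input_len > self.rem_input_tokens:
            return False
        self.rem_total_tokens -= projected
        self.rem_input_tokens -= req.input_len
        self.can_run_list.append(req)
        return True


adder = SimplePrefillAdder(rem_total_tokens=4096, rem_input_tokens=2048)
for r in [PendingReq(512, 128), PendingReq(1024, 256), PendingReq(2048, 512)]:
    if not adder.add_one_req(r):
        break
print(len(adder.can_run_list))  # 2: the third request exceeds the remaining budget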
@@ -534,47 +440,94 @@ class ModelTpServer:
         self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
         return new_batch
 
-    def forward_prefill_batch(self, batch: Batch):
+    def forward_prefill_batch(self, batch: ScheduleBatch):
         # Build batch tensors
         batch.prepare_for_extend(
             self.model_config.vocab_size, self.int_token_logit_bias
         )
 
-
-
-
-
-
-
-
-            output.next_token_logprobs
-
-
-
-
-
-            output.normalized_prompt_logprobs
-
+        if self.model_runner.is_generation:
+            # Forward and sample the next tokens
+            if batch.extend_num_tokens != 0:
+                output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+                next_token_ids = batch.sample(output.next_token_logits)
+
+                # Move logprobs to cpu
+                if output.next_token_logprobs is not None:
+                    output.next_token_logprobs = output.next_token_logprobs[
+                        torch.arange(len(next_token_ids), device=next_token_ids.device),
+                        next_token_ids,
+                    ].tolist()
+                    output.input_token_logprobs = output.input_token_logprobs.tolist()
+                    output.normalized_prompt_logprobs = (
+                        output.normalized_prompt_logprobs.tolist()
+                    )
 
-
-
-
+                next_token_ids = next_token_ids.tolist()
+            else:
+                if self.tokenizer is None:
+                    next_token_ids = []
+                    for req in batch.reqs:
+                        next_token_ids.append(
+                            next(iter(req.sampling_params.stop_token_ids))
+                        )
+                else:
+                    next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
+
+            # Check finish conditions
+            pt = 0
+            for i, req in enumerate(batch.reqs):
+                if req is not self.current_inflight_req:
+                    # Inflight reqs' prefill is not finished
+                    req.completion_tokens_wo_jump_forward += 1
+                    req.output_ids.append(next_token_ids[i])
+                    req.check_finished()
+
+                if req.finished():
+                    self.tree_cache.cache_finished_req(req)
+                else:
+                    self.tree_cache.cache_unfinished_req(req)
 
-
-
-
-            if req is not self.current_inflight_req:
-                req.completion_tokens_wo_jump_forward += 1
-                req.output_ids.append(next_token_ids[i])
-                req.check_finished()
+                if req is self.current_inflight_req:
+                    # Inflight request would get a new req idx
+                    self.req_to_token_pool.free(req.req_pool_idx)
 
-
-
-
+                if req.return_logprob:
+                    self.add_logprob_return_values(i, req, pt, next_token_ids, output)
+                    pt += req.extend_input_len
+        else:
+            assert batch.extend_num_tokens != 0
+            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            embeddings = output.embeddings.tolist()
+
+            # Check finish conditions
+            for i, req in enumerate(batch.reqs):
+                req.embedding = embeddings[i]
+                if req is not self.current_inflight_req:
+                    # Inflight reqs' prefill is not finished
+                    # dummy output token for embedding models
+                    req.output_ids.append(0)
+                    req.check_finished()
+
+                if req.finished():
+                    self.tree_cache.cache_finished_req(req)
+                else:
+                    self.tree_cache.cache_unfinished_req(req)
+
+                if req is self.current_inflight_req:
+                    # Inflight request would get a new req idx
+                    self.req_to_token_pool.free(req.req_pool_idx)
 
         self.handle_finished_requests(batch)
 
-    def add_logprob_return_values(
+    def add_logprob_return_values(
+        self,
+        i,
+        req: Req,
+        pt: int,
+        next_token_ids: List[int],
+        output: LogitProcessorOutput,
+    ):
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
 
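In the generation branch above, per-token log probabilities are gathered by indexing the logprob matrix with torch.arange over the batch dimension and the sampled token ids over the vocabulary dimension, then moved to the CPU as Python lists. A self-contained illustration of that indexing trick (shapes and values below are made up):

import torch

# [batch, vocab] log-probabilities and one sampled token id per request
logprobs = torch.log_softmax(torch.randn(4, 32000), dim=-1)
next_token_ids = torch.tensor([11, 7, 290, 1024])

# Row i is indexed at column next_token_ids[i], giving one scalar per request.
chosen = logprobs[torch.arange(len(next_token_ids)), next_token_ids].tolist()
print(chosen)  # four floats, one per request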
@@ -583,12 +536,12 @@ class ModelTpServer:
             req.input_token_logprobs = list(
                 zip(
                     output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
-                    req.
+                    req.fill_ids[-req.extend_input_len + 1 :],
                 )
             )
             if req.logprob_start_len == 0:
                 req.input_token_logprobs = [
-                    (None, req.
+                    (None, req.fill_ids[0])
                 ] + req.input_token_logprobs
 
         if req.last_update_decode_tokens != 0:
@@ -602,7 +555,7 @@ class ModelTpServer:
                             + req.extend_input_len
                             - 1
                         ],
-                        req.
+                        req.fill_ids[-req.last_update_decode_tokens + 1 :],
                     )
                 )
             )
@@ -623,24 +576,7 @@ class ModelTpServer:
             )
             req.output_top_logprobs.append(output.output_top_logprobs[i])
 
-    def cache_filled_batch(self, batch: Batch):
-        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
-        for i, req in enumerate(batch.reqs):
-            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                rid=req.rid,
-                token_ids=tuple(req.input_ids),
-                last_uncached_pos=len(req.prefix_indices),
-                req_pool_idx=req_pool_indices_cpu[i],
-                del_in_memory_pool=False,
-                old_last_node=req.last_node,
-            )
-            req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
-
-            if req is self.current_inflight_req:
-                # inflight request would get a new req idx
-                self.req_to_token_pool.free(int(req_pool_indices_cpu[i]))
-
-    def forward_decode_batch(self, batch: Batch):
+    def forward_decode_batch(self, batch: ScheduleBatch):
         # Check if decode out of memory
         if not batch.check_decode_mem():
             old_ratio = self.new_token_ratio
@@ -690,6 +626,9 @@ class ModelTpServer:
             req.output_ids.append(next_token_id)
             req.check_finished()
 
+            if req.finished():
+                self.tree_cache.cache_finished_req(req)
+
             if req.return_logprob:
                 req.output_token_logprobs.append(
                     (next_token_logprobs[i], next_token_id)
@@ -699,22 +638,23 @@ class ModelTpServer:
 
         self.handle_finished_requests(batch)
 
-    def handle_finished_requests(self, batch: Batch):
+    def handle_finished_requests(self, batch: ScheduleBatch):
         output_rids = []
-        output_vids = []
-        decoded_texts = []
-        output_read_ids = []
-        output_read_offsets = []
-        output_skip_special_tokens = []
-        output_spaces_between_special_tokens = []
         output_meta_info = []
         output_finished_reason: List[BaseFinishReason] = []
-        finished_indices = []
+        if self.model_runner.is_generation:
+            output_vids = []
+            decoded_texts = []
+            output_read_ids = []
+            output_read_offsets = []
+            output_skip_special_tokens = []
+            output_spaces_between_special_tokens = []
+        else:  # for embedding model
+            output_embeddings = []
         unfinished_indices = []
+
         for i, req in enumerate(batch.reqs):
-            if req.finished():
-                finished_indices.append(i)
-            else:
+            if not req.finished() and req is not self.current_inflight_req:
                 unfinished_indices.append(i)
 
             if req.finished() or (
@@ -727,86 +667,75 @@ class ModelTpServer:
                 )
             ):
                 output_rids.append(req.rid)
-                output_vids.append(req.vid)
-                decoded_texts.append(req.decoded_text)
-                read_ids, read_offset = req.init_incremental_detokenize()
-                output_read_ids.append(read_ids)
-                output_read_offsets.append(read_offset)
-                output_skip_special_tokens.append(
-                    req.sampling_params.skip_special_tokens
-                )
-                output_spaces_between_special_tokens.append(
-                    req.sampling_params.spaces_between_special_tokens
-                )
-
-                meta_info = {
-                    "prompt_tokens": len(req.origin_input_ids),
-                    "completion_tokens": len(req.output_ids),
-                    "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                    "finish_reason": str(req.finished_reason),
-                }
-                if req.return_logprob:
-                    (
-                        meta_info["input_token_logprobs"],
-                        meta_info["output_token_logprobs"],
-                        meta_info["input_top_logprobs"],
-                        meta_info["output_top_logprobs"],
-                        meta_info["normalized_prompt_logprob"],
-                    ) = (
-                        req.input_token_logprobs,
-                        req.output_token_logprobs,
-                        req.input_top_logprobs,
-                        req.output_top_logprobs,
-                        req.normalized_prompt_logprob,
-                    )
-                output_meta_info.append(meta_info)
                 output_finished_reason.append(req.finished_reason)
+                if self.model_runner.is_generation:
+                    output_vids.append(req.vid)
+                    decoded_texts.append(req.decoded_text)
+                    read_ids, read_offset = req.init_incremental_detokenize()
+                    output_read_ids.append(read_ids)
+                    output_read_offsets.append(read_offset)
+                    output_skip_special_tokens.append(
+                        req.sampling_params.skip_special_tokens
+                    )
+                    output_spaces_between_special_tokens.append(
+                        req.sampling_params.spaces_between_special_tokens
+                    )
+
+                    meta_info = {
+                        "prompt_tokens": len(req.origin_input_ids),
+                        "completion_tokens": len(req.output_ids),
+                        "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
+                        "finish_reason": str(req.finished_reason),
+                    }
+                    if req.return_logprob:
+                        (
+                            meta_info["input_token_logprobs"],
+                            meta_info["output_token_logprobs"],
+                            meta_info["input_top_logprobs"],
+                            meta_info["output_top_logprobs"],
+                            meta_info["normalized_prompt_logprob"],
+                        ) = (
+                            req.input_token_logprobs,
+                            req.output_token_logprobs,
+                            req.input_top_logprobs,
+                            req.output_top_logprobs,
+                            req.normalized_prompt_logprob,
+                        )
+                    output_meta_info.append(meta_info)
+                else:  # for embedding model
+                    output_embeddings.append(req.embedding)
+                    meta_info = {
+                        "prompt_tokens": len(req.origin_input_ids),
+                    }
+                    output_meta_info.append(meta_info)
 
         # Send to detokenizer
         if output_rids:
-            self.
-
-
-
-
-
-
-
-
-
-
+            if self.model_runner.is_generation:
+                self.out_pyobjs.append(
+                    BatchTokenIDOut(
+                        output_rids,
+                        output_vids,
+                        decoded_texts,
+                        output_read_ids,
+                        output_read_offsets,
+                        output_skip_special_tokens,
+                        output_spaces_between_special_tokens,
+                        output_meta_info,
+                        output_finished_reason,
+                    )
                 )
-
-
-
-
-
-
-
-
-                self.tree_cache.cache_req(
-                    rid=req.rid,
-                    token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
-                    last_uncached_pos=len(req.prefix_indices),
-                    req_pool_idx=req_pool_indices_cpu[i],
+            else:  # for embedding model
+                self.out_pyobjs.append(
+                    BatchEmbeddingOut(
+                        output_rids,
+                        output_embeddings,
+                        output_meta_info,
+                        output_finished_reason,
+                    )
                 )
 
-
-
-        # Update batch tensors
-        if unfinished_indices:
-            batch.filter_batch(unfinished_indices)
-        else:
-            batch.reqs = []
-
-    def filter_out_inflight(self, batch: Batch):
-        # TODO(lsyin): reduce the overhead, make a special version for this
-        if self.current_inflight_req is None:
-            return
-
-        to_remove = batch.reqs.index(self.current_inflight_req)
-        unfinished_indices = [i for i in range(len(batch.reqs)) if i != to_remove]
-
+        # Remove finished reqs: update batch tensors
         batch.filter_batch(unfinished_indices)
 
     def flush_cache(self):
@@ -873,7 +802,11 @@ def run_tp_server(
 
 
 def launch_tp_servers(
-    gpu_ids
+    gpu_ids: List[int],
+    tp_rank_range: List[int],
+    server_args: ServerArgs,
+    nccl_port: int,
+    model_overide_args: dict,
 ):
     """Launch multiple tensor parallel servers."""
     procs = []
@@ -888,7 +821,9 @@ def launch_tp_servers(
     return procs
 
 
-def broadcast_recv_input(
+def broadcast_recv_input(
+    data: Any, rank: int, dist_group: torch.distributed.ProcessGroup
+):
    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
 
    if rank == 0: