sglang 0.2.11__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +6 -4
- sglang/bench_serving.py +46 -22
- sglang/lang/compiler.py +2 -2
- sglang/lang/ir.py +3 -3
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/layers/activation.py +33 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +6 -1
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +5 -0
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +2 -2
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +110 -87
- sglang/srt/managers/tokenizer_manager.py +193 -111
- sglang/srt/managers/tp_worker.py +289 -352
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +24 -9
- sglang/srt/model_executor/forward_batch_info.py +168 -105
- sglang/srt/model_executor/model_runner.py +24 -37
- sglang/srt/models/gemma2.py +0 -1
- sglang/srt/models/internlm2.py +2 -7
- sglang/srt/models/llama2.py +4 -4
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/qwen2_moe.py +0 -11
- sglang/srt/openai_api/adapter.py +155 -27
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -4
- sglang/srt/server.py +69 -15
- sglang/srt/server_args.py +26 -19
- sglang/srt/utils.py +31 -13
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +63 -63
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +4 -2
- sglang/test/test_utils.py +20 -2
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/METADATA +23 -14
- sglang-0.2.12.dist-info/RECORD +112 -0
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang-0.2.11.dist-info/RECORD +0 -102
- {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
- {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
- {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
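Notable among the added files is the new sglang/srt/sampling/penaltylib/ package in the list above (an orchestrator plus frequency, presence, repetition, and min-new-tokens penalizers) together with its tests. As a rough conceptual sketch only — this is not the penaltylib API, and the function name and arguments below are invented for illustration — such penalizers adjust a request's next-token logits based on the tokens it has already produced:

import torch


def apply_penalties(
    logits: torch.Tensor,             # [vocab_size] next-token logits for one request
    output_ids: list,                 # token ids this request has generated so far
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    repetition_penalty: float = 1.0,
) -> torch.Tensor:
    counts = torch.zeros_like(logits)
    for tok in output_ids:
        counts[tok] += 1              # how often each token was emitted
    seen = counts > 0

    # Frequency penalty grows with the repeat count; presence penalty is a flat
    # offset applied to any token that appeared at least once.
    logits = logits - frequency_penalty * counts - presence_penalty * seen.float()

    # Repetition penalty (CTRL-style): divide positive logits and multiply
    # negative logits of previously seen tokens by the penalty factor.
    picked = logits[seen]
    logits[seen] = torch.where(
        picked > 0, picked / repetition_penalty, picked * repetition_penalty
    )
    return logits

A min-new-tokens penalizer is usually implemented differently: stop/EOS token logits are masked out until the request has produced the required minimum number of output tokens.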
sglang/srt/managers/tp_worker.py
CHANGED
@@ -17,25 +17,30 @@ limitations under the License.
 
 import logging
 import multiprocessing
+import os
 import pickle
 import time
 import warnings
-from typing import List, Optional
+from typing import Any, List, Optional, Union

 import torch
+import torch.distributed
 import torch.distributed as dist

 from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
+    BatchEmbeddingOut,
     BatchTokenIDOut,
     FlushCacheReq,
+    TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.policy_scheduler import PolicyScheduler
+from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     BaseFinishReason,
@@ -59,6 +64,9 @@ from sglang.utils import get_exception_traceback
 logger = logging.getLogger(__name__)


+crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
+
+
 class ModelTpServer:
     def __init__(
         self,
@@ -98,26 +106,24 @@ class ModelTpServer:
             nccl_port=nccl_port,
             server_args=server_args,
         )
-
-        if is_multimodal_model(server_args.model_path):
-            self.processor = get_processor(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
-            self.tokenizer = self.processor.tokenizer
+        if server_args.skip_tokenizer_init:
+            self.tokenizer = self.processor = None
         else:
-            self.tokenizer = get_tokenizer(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
+            if is_multimodal_model(server_args.model_path):
+                self.processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
+                self.tokenizer = self.processor.tokenizer
+            else:
+                self.tokenizer = get_tokenizer(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
-        self.max_prefill_tokens = (
-            16384
-            if server_args.max_prefill_tokens is None
-            else server_args.max_prefill_tokens
-        )
+        self.max_prefill_tokens = server_args.max_prefill_tokens
         self.max_running_requests = min(
             (
                 self.max_total_num_tokens // 2
@@ -160,13 +166,7 @@ class ModelTpServer:
             disable=server_args.disable_radix_cache,
         )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
-        self.scheduler = PolicyScheduler(
-            self.schedule_policy,
-            self.max_running_requests,
-            self.max_prefill_tokens,
-            self.max_total_num_tokens,
-            self.tree_cache,
-        )
+        self.scheduler = PolicyScheduler(self.schedule_policy, self.tree_cache)
         self.req_to_token_pool = self.model_runner.req_to_token_pool
         self.token_to_kv_pool = self.model_runner.token_to_kv_pool

@@ -180,13 +180,15 @@ class ModelTpServer:
         self.last_stats_tic = time.time()

         # Init the FSM cache for constrained generation
-        self.regex_fsm_cache = FSMCache(
-            server_args.tokenizer_path,
-            {
-                "tokenizer_mode": server_args.tokenizer_mode,
-                "trust_remote_code": server_args.trust_remote_code,
-            },
-        )
+        if not server_args.skip_tokenizer_init:
+            self.regex_fsm_cache = FSMCache(
+                server_args.tokenizer_path,
+                {
+                    "tokenizer_mode": server_args.tokenizer_mode,
+                    "trust_remote_code": server_args.trust_remote_code,
+                },
+                skip_tokenizer_init=server_args.skip_tokenizer_init,
+            )
         self.jump_forward_cache = JumpForwardCache()

         # Init new token estimation
@@ -201,11 +203,13 @@ class ModelTpServer:
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay

-    def exposed_step(self, recv_reqs):
+    def exposed_step(self, recv_reqs: List):
         try:
             # Recv requests
             for recv_req in recv_reqs:
-                if isinstance(recv_req, TokenizedGenerateReqInput):
+                if isinstance(
+                    recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                ):
                     self.handle_generate_request(recv_req)
                 elif isinstance(recv_req, FlushCacheReq):
                     self.flush_cache()
@@ -232,8 +236,6 @@ class ModelTpServer:
             if new_batch is not None:
                 # Run a new prefill batch
                 self.forward_prefill_batch(new_batch)
-                self.cache_filled_batch(new_batch)
-                self.filter_out_inflight(new_batch)

                 if not new_batch.is_empty():
                     if self.running_batch is None:
@@ -250,7 +252,7 @@ class ModelTpServer:

                 # Print stats
                 if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
-                    self.
+                    self.print_decode_stats()

             if self.running_batch.is_empty():
                 self.running_batch = None
@@ -262,7 +264,7 @@ class ModelTpServer:
             self.check_memory()
             self.new_token_ratio = global_config.init_new_token_ratio

-    def
+    def print_decode_stats(self):
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
@@ -288,6 +290,7 @@ class ModelTpServer:
                 f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
                 "KV cache pool leak detected!"
             )
+            exit(1) if crash_on_warning else None

         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
             warnings.warn(
@@ -296,44 +299,46 @@ class ModelTpServer:
                 f"total slots={self.req_to_token_pool.size}\n"
                 "Memory pool leak detected!"
             )
+            exit(1) if crash_on_warning else None

     def handle_generate_request(
         self,
-        recv_req: TokenizedGenerateReqInput,
+        recv_req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
     ):
         req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
-        req.pixel_values = recv_req.pixel_values
-        if req.pixel_values is not None:
-            req.pad_value = [
-                (recv_req.image_hash) % self.model_config.vocab_size,
-                (recv_req.image_hash >> 16) % self.model_config.vocab_size,
-                (recv_req.image_hash >> 32) % self.model_config.vocab_size,
-                (recv_req.image_hash >> 64) % self.model_config.vocab_size,
-            ]
-            req.image_size = recv_req.image_size
-            (
-                req.origin_input_ids,
-                req.image_offset,
-            ) = self.model_runner.model.pad_input_ids(
-                req.origin_input_ids_unpadded,
-                req.pad_value,
-                req.pixel_values.shape,
-                req.image_size,
-            )
-        req.sampling_params = recv_req.sampling_params
-        req.return_logprob = recv_req.return_logprob
-        req.logprob_start_len = recv_req.logprob_start_len
-        req.top_logprobs_num = recv_req.top_logprobs_num
-        req.stream = recv_req.stream
         req.tokenizer = self.tokenizer
-
-        # Init regex fsm
-        if req.sampling_params.regex is not None:
-            req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
-            if not self.disable_regex_jump_forward:
-                req.jump_forward_map = self.jump_forward_cache.query(
-                    req.sampling_params.regex
+        req.sampling_params = recv_req.sampling_params
+        if self.model_runner.is_generation:
+            req.pixel_values = recv_req.pixel_values
+            if req.pixel_values is not None:
+                req.pad_value = [
+                    (recv_req.image_hash) % self.model_config.vocab_size,
+                    (recv_req.image_hash >> 16) % self.model_config.vocab_size,
+                    (recv_req.image_hash >> 32) % self.model_config.vocab_size,
+                    (recv_req.image_hash >> 64) % self.model_config.vocab_size,
+                ]
+                req.image_size = recv_req.image_size
+                (
+                    req.origin_input_ids,
+                    req.image_offset,
+                ) = self.model_runner.model.pad_input_ids(
+                    req.origin_input_ids_unpadded,
+                    req.pad_value,
+                    req.pixel_values.shape,
+                    req.image_size,
                 )
+            req.return_logprob = recv_req.return_logprob
+            req.logprob_start_len = recv_req.logprob_start_len
+            req.top_logprobs_num = recv_req.top_logprobs_num
+            req.stream = recv_req.stream
+
+            # Init regex fsm
+            if req.sampling_params.regex is not None:
+                req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
+                if not self.disable_regex_jump_forward:
+                    req.jump_forward_map = self.jump_forward_cache.query(
+                        req.sampling_params.regex
+                    )

         # Truncate prompts that are too long
         if len(req.origin_input_ids) >= self.max_req_input_len:
@@ -342,186 +347,87 @@ class ModelTpServer:
                 "the max context length. Truncated!!!"
             )
             req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
-        req.sampling_params.max_new_tokens = min(
-            (
-                req.sampling_params.max_new_tokens
-                if req.sampling_params.max_new_tokens is not None
-                else 1 << 30
-            ),
-            self.max_req_input_len - 1 - len(req.origin_input_ids),
-        )
+
+        if self.model_runner.is_generation:
+            req.sampling_params.max_new_tokens = min(
+                (
+                    req.sampling_params.max_new_tokens
+                    if req.sampling_params.max_new_tokens is not None
+                    else 1 << 30
+                ),
+                self.max_req_input_len - 1 - len(req.origin_input_ids),
+            )
+
         self.waiting_queue.append(req)

     def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
-        # TODO(lsyin): organize this function
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
         if running_bs >= self.max_running_requests:
-            return
-
-        # Compute matched prefix length
-        for req in self.waiting_queue:
-            req.input_ids = req.origin_input_ids + req.output_ids
-            try_match_ids = req.input_ids
-            if req.return_logprob:
-                try_match_ids = req.input_ids[: req.logprob_start_len]
-            # NOTE: the prefix_indices must always be aligned with last_node
-            prefix_indices, last_node = self.tree_cache.match_prefix(
-                rid=req.rid, key=try_match_ids
-            )
-            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
-            req.prefix_indices = prefix_indices
-            req.last_node = last_node
+            return None

         # Get priority queue
-
+        prefix_computed = self.scheduler.calc_priority(self.waiting_queue)

-
-
-
-
-
-        available_size = (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+        adder = PrefillAdder(
+            self.tree_cache,
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
+            self.max_prefill_tokens,
+            self.chunked_prefill_size,
         )
-        if self.running_batch:
-            available_size -= sum(
-                [
-                    (r.sampling_params.max_new_tokens - len(r.output_ids))
-                    * self.new_token_ratio
-                    for r in self.running_batch.reqs
-                ]
-            )

-
-
-
-
-
-
-
-                len(r.input_ids) - len(r.prefix_indices) > self.chunked_prefill_size
+        if self.running_batch is not None:
+            adder.remove_running_tokens(self.running_batch, self.new_token_ratio)
+
+        has_inflight = self.current_inflight_req is not None
+        if self.current_inflight_req is not None:
+            self.current_inflight_req.init_next_round_input(
+                None if prefix_computed else self.tree_cache
             )
-
-
+            self.current_inflight_req = adder.add_inflight_req(
+                self.current_inflight_req
             )
-            r.input_ids = r.input_ids[: len(r.prefix_indices) + r.extend_input_len]
-            can_run_list.append(r)
-
-            if not truncated:
-                # Finish inflight
-                self.current_inflight_req = None
-                new_batch_total_tokens += (
-                    r.extend_input_len + r.sampling_params.max_new_tokens
-                )
-                new_batch_input_tokens += r.extend_input_len
-            else:
-                new_batch_total_tokens += r.extend_input_len
-                new_batch_input_tokens += r.extend_input_len

         for req in self.waiting_queue:
-
-
-            if req.extend_input_len < 2:
-                delta = 2 - req.extend_input_len
-                req.extend_input_len += delta
-                req.prefix_indices = req.prefix_indices[:-delta]
-                if req.image_offset is not None:
-                    req.image_offset += delta
-            if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
-                # Need at least one token to compute logits
-                req.extend_input_len = 1
-                req.prefix_indices = req.prefix_indices[:-1]
-                if req.image_offset is not None:
-                    req.image_offset += 1
-
+            req.init_next_round_input(None if prefix_computed else self.tree_cache)
+            res = adder.add_one_req(req)
             if (
-                req.extend_input_len
-                + req.sampling_params.max_new_tokens
-                + new_batch_total_tokens
-                < available_size
-                and (
-                    req.extend_input_len + new_batch_input_tokens
-                    <= self.max_prefill_tokens
-                    or len(can_run_list) == 0
-                )
+                not res
+                or adder.no_remaining_tokens()
+                or running_bs + len(adder.can_run_list) >= self.max_running_requests
             ):
-                delta = self.tree_cache.inc_lock_ref(req.last_node)
-                available_size += delta
-
-                if not (
-                    req.extend_input_len
-                    + req.sampling_params.max_new_tokens
-                    + new_batch_total_tokens
-                    < available_size
-                ):
-                    # Undo locking
-                    delta = self.tree_cache.dec_lock_ref(req.last_node)
-                    available_size += delta
-                    break
-                else:
-                    # Add this request to the running batch
-                    if (
-                        self.chunked_prefill_size is None
-                        or (
-                            new_batch_input_tokens + req.extend_input_len
-                            <= self.chunked_prefill_size
-                        )
-                        or (
-                            req.return_logprob and req.normalized_prompt_logprob is None
-                        )
-                    ):
-                        can_run_list.append(req)
-                        new_batch_total_tokens += (
-                            req.extend_input_len + req.sampling_params.max_new_tokens
-                        )
-                        new_batch_input_tokens += req.extend_input_len
-                    else:
-                        trunc_len = self.chunked_prefill_size - new_batch_input_tokens
-
-                        if trunc_len <= 0:
-                            # Undo locking
-                            delta = self.tree_cache.dec_lock_ref(req.last_node)
-                            available_size += delta
-                            break
-
-                        req.extend_input_len = trunc_len
-                        req.input_ids = req.input_ids[
-                            : len(req.prefix_indices) + req.extend_input_len
-                        ]
-                        can_run_list.append(req)
-                        self.current_inflight_req = req
-                        new_batch_input_tokens += req.extend_input_len
-                        new_batch_total_tokens += req.extend_input_len
-                        break
-            else:
                 break

-
-
+        can_run_list = adder.can_run_list
+
+        if adder.new_inflight_req is not None:
+            assert self.current_inflight_req is None
+            self.current_inflight_req = adder.new_inflight_req

         if len(can_run_list) == 0:
             return None

         # Print stats
         if self.tp_rank == 0:
-
-
-
-
-
-
-
-
+            if isinstance(self.tree_cache, RadixCache):
+                self.tree_cache_metrics["total"] += (
+                    adder.log_input_tokens + adder.log_hit_tokens
+                ) / 10**9
+                self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
+                tree_cache_hit_rate = (
+                    self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
+                )
+            else:
+                tree_cache_hit_rate = 0.0
             logger.info(
                 f"[gpu={self.gpu_id}] Prefill batch. "
                 f"#new-seq: {len(can_run_list)}, "
-                f"#new-token: {
-                f"#cached-token: {
+                f"#new-token: {adder.log_input_tokens}, "
+                f"#cached-token: {adder.log_hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                 f"#running-req: {running_bs}, "
-                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) +
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
             )

         # Return the new batch
@@ -540,41 +446,88 @@ class ModelTpServer:
             self.model_config.vocab_size, self.int_token_logit_bias
         )

-
-
-
-
-
-
-
-            output.next_token_logprobs
-
-
-
-
-
-            output.normalized_prompt_logprobs
-
+        if self.model_runner.is_generation:
+            # Forward and sample the next tokens
+            if batch.extend_num_tokens != 0:
+                output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+                next_token_ids = batch.sample(output.next_token_logits)
+
+                # Move logprobs to cpu
+                if output.next_token_logprobs is not None:
+                    output.next_token_logprobs = output.next_token_logprobs[
+                        torch.arange(len(next_token_ids), device=next_token_ids.device),
+                        next_token_ids,
+                    ].tolist()
+                    output.input_token_logprobs = output.input_token_logprobs.tolist()
+                    output.normalized_prompt_logprobs = (
+                        output.normalized_prompt_logprobs.tolist()
+                    )

-
-
-
+                next_token_ids = next_token_ids.tolist()
+            else:
+                if self.tokenizer is None:
+                    next_token_ids = []
+                    for req in batch.reqs:
+                        next_token_ids.append(
+                            next(iter(req.sampling_params.stop_token_ids))
+                        )
+                else:
+                    next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
+
+            # Check finish conditions
+            pt = 0
+            for i, req in enumerate(batch.reqs):
+                if req is not self.current_inflight_req:
+                    # Inflight reqs' prefill is not finished
+                    req.completion_tokens_wo_jump_forward += 1
+                    req.output_ids.append(next_token_ids[i])
+                    req.check_finished()
+
+                if req.finished():
+                    self.tree_cache.cache_finished_req(req)
+                else:
+                    self.tree_cache.cache_unfinished_req(req)

-
-
-
-            if req is not self.current_inflight_req:
-                req.completion_tokens_wo_jump_forward += 1
-                req.output_ids.append(next_token_ids[i])
-                req.check_finished()
+                if req is self.current_inflight_req:
+                    # Inflight request would get a new req idx
+                    self.req_to_token_pool.free(req.req_pool_idx)

-
-
-
+                if req.return_logprob:
+                    self.add_logprob_return_values(i, req, pt, next_token_ids, output)
+                    pt += req.extend_input_len
+        else:
+            assert batch.extend_num_tokens != 0
+            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            embeddings = output.embeddings.tolist()
+
+            # Check finish conditions
+            for i, req in enumerate(batch.reqs):
+                req.embedding = embeddings[i]
+                if req is not self.current_inflight_req:
+                    # Inflight reqs' prefill is not finished
+                    # dummy output token for embedding models
+                    req.output_ids.append(0)
+                    req.check_finished()
+
+                if req.finished():
+                    self.tree_cache.cache_finished_req(req)
+                else:
+                    self.tree_cache.cache_unfinished_req(req)
+
+                if req is self.current_inflight_req:
+                    # Inflight request would get a new req idx
+                    self.req_to_token_pool.free(req.req_pool_idx)

         self.handle_finished_requests(batch)

-    def add_logprob_return_values(
+    def add_logprob_return_values(
+        self,
+        i,
+        req: Req,
+        pt: int,
+        next_token_ids: List[int],
+        output: LogitProcessorOutput,
+    ):
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

@@ -583,12 +536,12 @@ class ModelTpServer:
             req.input_token_logprobs = list(
                 zip(
                     output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
-                    req.
+                    req.fill_ids[-req.extend_input_len + 1 :],
                 )
             )
             if req.logprob_start_len == 0:
                 req.input_token_logprobs = [
-                    (None, req.
+                    (None, req.fill_ids[0])
                 ] + req.input_token_logprobs

         if req.last_update_decode_tokens != 0:
@@ -602,7 +555,7 @@ class ModelTpServer:
                             + req.extend_input_len
                             - 1
                         ],
-                        req.
+                        req.fill_ids[-req.last_update_decode_tokens + 1 :],
                     )
                 )
             )
@@ -623,22 +576,6 @@ class ModelTpServer:
             )
             req.output_top_logprobs.append(output.output_top_logprobs[i])

-    def cache_filled_batch(self, batch: ScheduleBatch):
-        for i, req in enumerate(batch.reqs):
-            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                rid=req.rid,
-                token_ids=tuple(req.input_ids),
-                last_uncached_pos=len(req.prefix_indices),
-                req_pool_idx=req.req_pool_idx,
-                del_in_memory_pool=False,
-                old_last_node=req.last_node,
-            )
-            req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
-
-            if req is self.current_inflight_req:
-                # inflight request would get a new req idx
-                self.req_to_token_pool.free(req.req_pool_idx)
-
     def forward_decode_batch(self, batch: ScheduleBatch):
         # Check if decode out of memory
         if not batch.check_decode_mem():
@@ -689,6 +626,9 @@ class ModelTpServer:
             req.output_ids.append(next_token_id)
             req.check_finished()

+            if req.finished():
+                self.tree_cache.cache_finished_req(req)
+
             if req.return_logprob:
                 req.output_token_logprobs.append(
                     (next_token_logprobs[i], next_token_id)
@@ -700,20 +640,21 @@ class ModelTpServer:

     def handle_finished_requests(self, batch: ScheduleBatch):
         output_rids = []
-        output_vids = []
-        decoded_texts = []
-        output_read_ids = []
-        output_read_offsets = []
-        output_skip_special_tokens = []
-        output_spaces_between_special_tokens = []
         output_meta_info = []
         output_finished_reason: List[BaseFinishReason] = []
-        finished_indices = []
+        if self.model_runner.is_generation:
+            output_vids = []
+            decoded_texts = []
+            output_read_ids = []
+            output_read_offsets = []
+            output_skip_special_tokens = []
+            output_spaces_between_special_tokens = []
+        else:  # for embedding model
+            output_embeddings = []
         unfinished_indices = []
+
         for i, req in enumerate(batch.reqs):
-            if req.finished():
-                finished_indices.append(i)
-            else:
+            if not req.finished() and req is not self.current_inflight_req:
                 unfinished_indices.append(i)

             if req.finished() or (
@@ -726,85 +667,75 @@ class ModelTpServer:
                 )
             ):
                 output_rids.append(req.rid)
-                output_vids.append(req.vid)
-                decoded_texts.append(req.decoded_text)
-                read_ids, read_offset = req.init_incremental_detokenize()
-                output_read_ids.append(read_ids)
-                output_read_offsets.append(read_offset)
-                output_skip_special_tokens.append(
-                    req.sampling_params.skip_special_tokens
-                )
-                output_spaces_between_special_tokens.append(
-                    req.sampling_params.spaces_between_special_tokens
-                )
-
-                meta_info = {
-                    "prompt_tokens": len(req.origin_input_ids),
-                    "completion_tokens": len(req.output_ids),
-                    "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                    "finish_reason": str(req.finished_reason),
-                }
-                if req.return_logprob:
-                    (
-                        meta_info["input_token_logprobs"],
-                        meta_info["output_token_logprobs"],
-                        meta_info["input_top_logprobs"],
-                        meta_info["output_top_logprobs"],
-                        meta_info["normalized_prompt_logprob"],
-                    ) = (
-                        req.input_token_logprobs,
-                        req.output_token_logprobs,
-                        req.input_top_logprobs,
-                        req.output_top_logprobs,
-                        req.normalized_prompt_logprob,
-                    )
-                output_meta_info.append(meta_info)
                 output_finished_reason.append(req.finished_reason)
+                if self.model_runner.is_generation:
+                    output_vids.append(req.vid)
+                    decoded_texts.append(req.decoded_text)
+                    read_ids, read_offset = req.init_incremental_detokenize()
+                    output_read_ids.append(read_ids)
+                    output_read_offsets.append(read_offset)
+                    output_skip_special_tokens.append(
+                        req.sampling_params.skip_special_tokens
+                    )
+                    output_spaces_between_special_tokens.append(
+                        req.sampling_params.spaces_between_special_tokens
+                    )
+
+                    meta_info = {
+                        "prompt_tokens": len(req.origin_input_ids),
+                        "completion_tokens": len(req.output_ids),
+                        "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
+                        "finish_reason": str(req.finished_reason),
+                    }
+                    if req.return_logprob:
+                        (
+                            meta_info["input_token_logprobs"],
+                            meta_info["output_token_logprobs"],
+                            meta_info["input_top_logprobs"],
+                            meta_info["output_top_logprobs"],
+                            meta_info["normalized_prompt_logprob"],
+                        ) = (
+                            req.input_token_logprobs,
+                            req.output_token_logprobs,
+                            req.input_top_logprobs,
+                            req.output_top_logprobs,
+                            req.normalized_prompt_logprob,
+                        )
+                    output_meta_info.append(meta_info)
+                else:  # for embedding model
+                    output_embeddings.append(req.embedding)
+                    meta_info = {
+                        "prompt_tokens": len(req.origin_input_ids),
+                    }
+                    output_meta_info.append(meta_info)

         # Send to detokenizer
         if output_rids:
-            self.out_pyobjs.append(
-                BatchTokenIDOut(
-                    output_rids,
-                    output_vids,
-                    decoded_texts,
-                    output_read_ids,
-                    output_read_offsets,
-                    output_skip_special_tokens,
-                    output_spaces_between_special_tokens,
-                    output_meta_info,
-                    output_finished_reason,
+            if self.model_runner.is_generation:
+                self.out_pyobjs.append(
+                    BatchTokenIDOut(
+                        output_rids,
+                        output_vids,
+                        decoded_texts,
+                        output_read_ids,
+                        output_read_offsets,
+                        output_skip_special_tokens,
+                        output_spaces_between_special_tokens,
+                        output_meta_info,
+                        output_finished_reason,
+                    )
                 )
-
-
-
-
-
-
-
-
-                    rid=req.rid,
-                    token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
-                    last_uncached_pos=len(req.prefix_indices),
-                    req_pool_idx=req.req_pool_idx,
+            else:  # for embedding model
+                self.out_pyobjs.append(
+                    BatchEmbeddingOut(
+                        output_rids,
+                        output_embeddings,
+                        output_meta_info,
+                        output_finished_reason,
+                    )
                 )

-
-
-        # Update batch tensors
-        if unfinished_indices:
-            batch.filter_batch(unfinished_indices)
-        else:
-            batch.reqs = []
-
-    def filter_out_inflight(self, batch: ScheduleBatch):
-        # TODO(lsyin): reduce the overhead, make a special version for this
-        if self.current_inflight_req is None:
-            return
-
-        to_remove = batch.reqs.index(self.current_inflight_req)
-        unfinished_indices = [i for i in range(len(batch.reqs)) if i != to_remove]
-
+        # Remove finished reqs: update batch tensors
         batch.filter_batch(unfinished_indices)

     def flush_cache(self):
@@ -871,7 +802,11 @@ def run_tp_server(


 def launch_tp_servers(
-    gpu_ids
+    gpu_ids: List[int],
+    tp_rank_range: List[int],
+    server_args: ServerArgs,
+    nccl_port: int,
+    model_overide_args: dict,
 ):
     """Launch multiple tensor parallel servers."""
     procs = []
@@ -886,7 +821,9 @@ def launch_tp_servers(
     return procs


-def broadcast_recv_input(
+def broadcast_recv_input(
+    data: Any, rank: int, dist_group: torch.distributed.ProcessGroup
+):
     """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""

     if rank == 0:
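For orientation, the largest structural change in tp_worker.py is that get_new_prefill_batch no longer keeps its own token-budget bookkeeping (available_size, new_batch_total_tokens, new_batch_input_tokens); it asks a PrefillAdder from policy_scheduler.py to admit requests and then reads adder.can_run_list. The toy sketch below mimics that budget-based admission idea in spirit only — it is not sglang's PrefillAdder, and the class, fields, and request attributes are invented for illustration:

from dataclasses import dataclass, field
from typing import List


@dataclass
class PendingReq:
    rid: str
    extend_len: int      # new tokens this request still needs to prefill
    cached_len: int = 0  # tokens already covered by the prefix cache


@dataclass
class ToyPrefillAdder:
    rem_total_tokens: int   # free KV-cache tokens (including evictable cache)
    rem_input_tokens: int   # per-batch prefill token budget
    can_run_list: List[PendingReq] = field(default_factory=list)
    log_input_tokens: int = 0
    log_hit_tokens: int = 0

    def add_one_req(self, req: PendingReq) -> bool:
        """Admit the request only if both budgets still have room."""
        if req.extend_len > self.rem_total_tokens or req.extend_len > self.rem_input_tokens:
            return False
        self.rem_total_tokens -= req.extend_len
        self.rem_input_tokens -= req.extend_len
        self.can_run_list.append(req)
        self.log_input_tokens += req.extend_len
        self.log_hit_tokens += req.cached_len
        return True


waiting_queue = [
    PendingReq("a", extend_len=512, cached_len=128),
    PendingReq("b", extend_len=30_000),
    PendingReq("c", extend_len=2_048),
]
adder = ToyPrefillAdder(rem_total_tokens=40_000, rem_input_tokens=16_384)
for req in waiting_queue:
    if not adder.add_one_req(req):
        break  # stop at the first request that does not fit
print([r.rid for r in adder.can_run_list], adder.log_input_tokens)

The design point visible in the diff is that centralizing admission in one helper lets the same loop also cover chunked prefill and the in-flight request (adder.add_inflight_req), while per-request prefix matching moves into req.init_next_round_input.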