sglang 0.3.5__py3-none-any.whl → 0.3.5.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +113 -3
- sglang/srt/configs/model_config.py +5 -2
- sglang/srt/constrained/__init__.py +2 -66
- sglang/srt/constrained/base_grammar_backend.py +72 -0
- sglang/srt/constrained/outlines_backend.py +165 -0
- sglang/srt/constrained/outlines_jump_forward.py +182 -0
- sglang/srt/constrained/xgrammar_backend.py +114 -0
- sglang/srt/layers/attention/triton_ops/decode_attention.py +7 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +6 -0
- sglang/srt/layers/fused_moe/fused_moe.py +23 -7
- sglang/srt/layers/quantization/base_config.py +4 -6
- sglang/srt/layers/vocab_parallel_embedding.py +216 -150
- sglang/srt/managers/io_struct.py +5 -3
- sglang/srt/managers/schedule_batch.py +14 -20
- sglang/srt/managers/scheduler.py +153 -94
- sglang/srt/managers/tokenizer_manager.py +81 -17
- sglang/srt/metrics/collector.py +211 -0
- sglang/srt/metrics/func_timer.py +108 -0
- sglang/srt/mm_utils.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +2 -2
- sglang/srt/model_executor/forward_batch_info.py +7 -3
- sglang/srt/model_executor/model_runner.py +2 -1
- sglang/srt/models/gemma2_reward.py +69 -0
- sglang/srt/models/gpt2.py +31 -37
- sglang/srt/models/internlm2_reward.py +62 -0
- sglang/srt/models/llama.py +11 -6
- sglang/srt/models/llama_reward.py +5 -26
- sglang/srt/models/qwen2_vl.py +5 -7
- sglang/srt/openai_api/adapter.py +6 -2
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/sampling/sampling_params.py +0 -14
- sglang/srt/server.py +58 -16
- sglang/srt/server_args.py +42 -22
- sglang/srt/utils.py +87 -0
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_mgsm.py +2 -2
- sglang/test/test_utils.py +18 -4
- sglang/utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/METADATA +11 -7
- {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/RECORD +45 -42
- {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/WHEEL +1 -1
- sglang/srt/constrained/base_tool_cache.py +0 -65
- sglang/srt/constrained/bnf_cache.py +0 -61
- sglang/srt/constrained/fsm_cache.py +0 -95
- sglang/srt/constrained/grammar.py +0 -190
- sglang/srt/constrained/jump_forward.py +0 -203
- {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/top_level.txt +0 -0
@@ -37,7 +37,7 @@ import torch
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained.
+from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -107,12 +107,14 @@ class FINISH_LENGTH(BaseFinishReason):
 
 
 class FINISH_ABORT(BaseFinishReason):
-    def __init__(self):
+    def __init__(self, message="Unknown error"):
         super().__init__(is_error=True)
+        self.message = message
 
     def to_json(self):
         return {
             "type": "abort",
+            "message": self.message,
         }
 
 
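Note on the FINISH_ABORT change above: abort reasons now carry a human-readable message that is included in the JSON payload. A minimal standalone sketch of the new behavior (the BaseFinishReason stub below is assumed for illustration; only FINISH_ABORT itself comes from the diff):

# Stub standing in for sglang's BaseFinishReason (not shown in this diff).
class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message="Unknown error"):
        super().__init__(is_error=True)
        self.message = message

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,  # new in 0.3.5.post1
        }


print(FINISH_ABORT("request aborted by client").to_json())
# {'type': 'abort', 'message': 'request aborted by client'}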
@@ -133,6 +135,7 @@ class ImageInputs:
     aspect_ratio_mask: Optional[List[torch.Tensor]] = None
     # QWen2-VL related
     image_grid_thws: List[Tuple[int, int, int]] = None
+    mrope_position_delta: Optional[torch.Tensor] = None
 
     @staticmethod
     def from_dict(obj, vocab_size):
@@ -211,7 +214,7 @@ class Req:
         # this does not include the jump forward tokens.
         self.completion_tokens_wo_jump_forward = 0
 
-        # For
+        # For multimodal inputs
        self.image_inputs: Optional[ImageInputs] = None
 
         # Prefix info
@@ -246,14 +249,11 @@ class Req:
         self.embedding = None
 
         # Constrained decoding
-        self.grammar: Optional[
+        self.grammar: Optional[BaseGrammarObject] = None
 
         # The number of cached tokens, that were already cached in the KV cache
         self.cached_tokens = 0
 
-        # For Qwen2-VL
-        self.mrope_position_delta = []  # use mutable object
-
     # whether request reached finished condition
     def finished(self) -> bool:
         return self.finished_reason is not None
@@ -359,8 +359,6 @@ class Req:
                     return
 
     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-        assert self.grammar is not None and self.tokenizer is not None
-
         if self.origin_input_text is None:
             # Recovering text can only use unpadded ids
             self.origin_input_text = self.tokenizer.decode(
@@ -809,9 +807,10 @@ class ScheduleBatch:
 
         for i, req in enumerate(self.reqs):
             if req.grammar is not None:
-                jump_helper = req.grammar.
-                if jump_helper
-                    suffix_ids = jump_helper
+                jump_helper = req.grammar.try_jump_forward(req.tokenizer)
+                if jump_helper:
+                    suffix_ids, _ = jump_helper
+
                     # Current ids, for cache and revert
                     cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                     cur_output_ids = req.output_ids
@@ -827,6 +826,8 @@ class ScheduleBatch:
                         next_state,
                     ) = req.grammar.jump_forward_str_state(jump_helper)
 
+                    # Make the incrementally decoded text part of jump_forward_str
+                    # so that the UTF-8 will not corrupt
                     jump_forward_str = new_text + jump_forward_str
                     if not req.jump_forward_and_retokenize(
                         jump_forward_str, next_state
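The two comments added above guard against a real pitfall: handling the jump-forward suffix separately from the incrementally decoded text can split a multi-byte UTF-8 character at the boundary. The snippet below only illustrates that failure mode at the byte level; it is not code from sglang:

# A multi-byte character split across two chunks decodes to replacement
# characters when each chunk is decoded on its own.
data = "café".encode("utf-8")      # b'caf\xc3\xa9'
left, right = data[:4], data[4:]   # cut inside the 2-byte "é"

print(left.decode("utf-8", errors="replace"))   # caf�
print(right.decode("utf-8", errors="replace"))  # �
print((left + right).decode("utf-8"))           # café  (concatenate first, then decode)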
@@ -900,8 +901,7 @@ class ScheduleBatch:
         keep_indices = [
             i
             for i in range(len(self.reqs))
-            if not self.reqs[i].finished()
-            and self.reqs[i] is not being_chunked_req
+            if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
         ]
 
         if keep_indices is None or len(keep_indices) == 0:
@@ -984,8 +984,6 @@ class ScheduleBatch:
         global bid
         bid += 1
 
-        mrope_positions_delta = [req.mrope_position_delta for req in self.reqs]
-
         return ModelWorkerBatch(
             bid=bid,
             forward_mode=self.forward_mode,
@@ -1008,7 +1006,6 @@ class ScheduleBatch:
             encoder_out_cache_loc=self.encoder_out_cache_loc,
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
-            mrope_positions_delta=mrope_positions_delta,
         )
 
     def copy(self):
@@ -1075,9 +1072,6 @@ class ModelWorkerBatch:
     # Sampling info
     sampling_info: SamplingBatchInfo
 
-    # For Qwen2-VL
-    mrope_positions_delta: List[List[int]]
-
     def copy(self):
         return dataclasses.replace(self, sampling_info=self.sampling_info.copy())
 
sglang/srt/managers/scheduler.py CHANGED
@@ -21,6 +21,7 @@ import threading
 import time
 import warnings
 from collections import deque
+from concurrent import futures
 from types import SimpleNamespace
 from typing import List, Optional
 
@@ -29,7 +30,6 @@ import zmq
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained.grammar import GrammarCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
@@ -62,6 +62,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
+from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     broadcast_pyobj,
@@ -99,11 +100,12 @@ class Scheduler:
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
         self.schedule_policy = server_args.schedule_policy
-        self.
+        self.disable_jump_forward = server_args.disable_jump_forward
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
         self.enable_overlap = server_args.enable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
+        self.enable_metrics = server_args.enable_metrics
 
         # Init inter-process communication
         context = zmq.Context(2)
@@ -222,7 +224,7 @@ class Scheduler:
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
-        self.
+        self.last_decode_stats_tic = time.time()
         self.stream_interval = server_args.stream_interval
 
         # Init chunked prefill
@@ -232,21 +234,33 @@ class Scheduler:
             self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
         )
 
-        # Init the
-        self.
-
+        # Init the grammar backend for constrained generation
+        self.grammar_queue: List[Req] = []
         if not server_args.skip_tokenizer_init:
-
-
-
-
-
-
-
-
-
-
-
+            if server_args.grammar_backend == "outlines":
+                from sglang.srt.constrained.outlines_backend import (
+                    OutlinesGrammarBackend,
+                )
+
+                self.grammar_backend = OutlinesGrammarBackend(
+                    self.tokenizer,
+                    whitespace_pattern=server_args.constrained_json_whitespace_pattern,
+                    allow_jump_forward=not server_args.disable_jump_forward,
+                )
+            elif server_args.grammar_backend == "xgrammar":
+                from sglang.srt.constrained.xgrammar_backend import (
+                    XGrammarGrammarBackend,
+                )
+
+                self.grammar_backend = XGrammarGrammarBackend(
+                    self.tokenizer, vocab_size=self.model_config.vocab_size
+                )
+            else:
+                raise ValueError(
+                    f"Invalid grammar backend: {server_args.grammar_backend}"
+                )
+        else:
+            self.grammar_backend = None
 
         # Init new token estimation
         assert (
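The new backend files (base_grammar_backend.py, outlines_backend.py, xgrammar_backend.py) are not expanded in this view. From the call sites in this file (get_cached_value, get_future_value, result(timeout=...), reset), the shared backend appears to be a cache that compiles grammars on a background executor and hands out futures. The sketch below is a plausible reconstruction under that assumption, not the actual sglang implementation:

from concurrent.futures import Future, ThreadPoolExecutor
from threading import Lock
from typing import Dict, Tuple


class SimpleGrammarBackend:
    """Hypothetical cache-of-futures backend; structure is illustrative and
    not copied from sglang's base_grammar_backend.py."""

    def __init__(self, num_workers: int = 2):
        self.executor = ThreadPoolExecutor(max_workers=num_workers)
        self.cache: Dict[Tuple[str, str], Future] = {}
        self.lock = Lock()

    def _compile(self, key: Tuple[str, str]):
        # Placeholder for expensive grammar compilation (regex FSM, JSON schema, ...).
        mode, value = key
        return f"compiled<{mode}>"

    def get_cached_value(self, key):
        # Return the finished grammar object if it has already been compiled.
        with self.lock:
            fut = self.cache.get(key)
        if fut is not None and fut.done():
            return fut.result()
        return None

    def get_future_value(self, key) -> Future:
        # Start (or reuse) a background compilation and return its future,
        # which the scheduler later polls with result(timeout=...).
        with self.lock:
            fut = self.cache.get(key)
            if fut is None:
                fut = self.executor.submit(self._compile, key)
                self.cache[key] = fut
        return fut

    def reset(self):
        with self.lock:
            self.cache.clear()

With a backend shaped like this, a cache hit returns a ready grammar object immediately, while a miss returns a future that the scheduler parks in grammar_queue and polls later.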
@@ -292,6 +306,16 @@ class Scheduler:
                 with_stack=True,
             )
 
+        # Init metrics stats
+        self.stats = SchedulerStats()
+        if self.enable_metrics:
+            self.metrics_collector = SchedulerMetricsCollector(
+                labels={
+                    "model_name": self.server_args.served_model_name,
+                    # TODO: Add lora name/path in the future,
+                },
+            )
+
     def watchdog_thread(self):
         self.watchdog_last_forward_ct = 0
         self.watchdog_last_time = time.time()
@@ -443,22 +467,6 @@ class Scheduler:
             # By default, only return the logprobs for output tokens
             req.logprob_start_len = len(recv_req.input_ids) - 1
 
-        # Init regex FSM or BNF
-        if (
-            req.sampling_params.json_schema is not None
-            or req.sampling_params.regex is not None
-        ):
-            assert self.grammar_cache is not None
-            if req.sampling_params.json_schema is not None:
-                req.grammar = self.grammar_cache.query(
-                    ("json", req.sampling_params.json_schema),
-                    self.model_config.vocab_size,
-                )
-            elif req.sampling_params.regex is not None:
-                req.grammar = self.grammar_cache.query(
-                    ("regex", req.sampling_params.regex), self.model_config.vocab_size
-                )
-
         # Truncate prompts that are too long
         if len(req.origin_input_ids) > self.max_req_input_len:
             logger.warning(
@@ -476,7 +484,27 @@ class Scheduler:
                 self.max_req_len - len(req.origin_input_ids) - 1,
             )
 
-
+        # Init grammar cache for this request
+        add_to_grammar_queue = False
+        if (
+            req.sampling_params.json_schema is not None
+            or req.sampling_params.regex is not None
+        ):
+            assert self.grammar_backend is not None
+            if req.sampling_params.json_schema is not None:
+                key = ("json", req.sampling_params.json_schema)
+            elif req.sampling_params.regex is not None:
+                key = ("regex", req.sampling_params.regex)
+
+            req.grammar = self.grammar_backend.get_cached_value(key)
+            if not req.grammar:
+                req.grammar = self.grammar_backend.get_future_value(key)
+                add_to_grammar_queue = True
+
+        if add_to_grammar_queue:
+            self.grammar_queue.append(req)
+        else:
+            self.waiting_queue.append(req)
 
     def handle_embedding_request(
         self,
@@ -500,23 +528,68 @@ class Scheduler:
 
         self.waiting_queue.append(req)
 
-    def
+    def log_prefill_stats(self, adder, can_run_list, running_bs, has_inflight):
+        if isinstance(self.tree_cache, RadixCache):
+            self.tree_cache_metrics["total"] += (
+                adder.log_input_tokens + adder.log_hit_tokens
+            ) / 10**9
+            self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
+            tree_cache_hit_rate = (
+                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
+            )
+        else:
+            tree_cache_hit_rate = 0.0
+
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
-
+
+        logger.info(
+            f"Prefill batch. "
+            f"#new-seq: {len(can_run_list)}, "
+            f"#new-token: {adder.log_input_tokens}, "
+            f"#cached-token: {adder.log_hit_tokens}, "
+            f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+            f"#running-req: {running_bs}, "
+            f"#queue-req: {len(self.waiting_queue) + has_inflight}"
+        )
+
+        if self.enable_metrics:
+            self.stats.num_running_reqs = running_bs
+            self.stats.num_used_tokens = num_used
+            self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
+            self.stats.num_queue_reqs = len(self.waiting_queue) + has_inflight
+            self.stats.cache_hit_rate = tree_cache_hit_rate
+            self.metrics_collector.log_stats(self.stats)
+
+    def log_decode_stats(self):
+        num_used = self.max_total_num_tokens - (
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+        )
+        gen_throughput = self.num_generated_tokens / (
+            time.time() - self.last_decode_stats_tic
+        )
         self.num_generated_tokens = 0
-        self.
+        self.last_decode_stats_tic = time.time()
         num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
         logger.info(
             f"Decode batch. "
             f"#running-req: {num_running_reqs}, "
             f"#token: {num_used}, "
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-            f"gen throughput (token/s): {
+            f"gen throughput (token/s): {gen_throughput:.2f}, "
             f"#queue-req: {len(self.waiting_queue)}"
         )
 
+        if self.enable_metrics:
+            self.stats.num_running_reqs = num_running_reqs
+            self.stats.num_used_tokens = num_used
+            self.stats.token_usage = num_used / self.max_total_num_tokens
+            self.stats.gen_throughput = gen_throughput
+            self.stats.num_queue_reqs = len(self.waiting_queue)
+            self.metrics_collector.log_stats(self.stats)
+
     def check_memory(self):
         available_size = (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
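The decode-side accounting above is interval based: num_generated_tokens accumulates between log lines, is divided by the time since last_decode_stats_tic, and both are then reset. The same pattern in isolation (ThroughputMeter is a hypothetical helper, not an sglang class):

import time


class ThroughputMeter:
    """Hypothetical helper mirroring the num_generated_tokens /
    last_decode_stats_tic bookkeeping; not an sglang class."""

    def __init__(self):
        self.num_generated_tokens = 0
        self.last_tic = time.time()

    def add(self, n: int):
        self.num_generated_tokens += n

    def rate(self) -> float:
        elapsed = max(time.time() - self.last_tic, 1e-6)  # guard against zero
        tokens_per_s = self.num_generated_tokens / elapsed
        # Reset the window, as the scheduler does after emitting a log line.
        self.num_generated_tokens = 0
        self.last_tic = time.time()
        return tokens_per_s


meter = ThroughputMeter()
meter.add(256)
time.sleep(0.1)
print(f"{meter.rate():.1f} token/s")  # roughly 2560 token/s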
@@ -546,9 +619,7 @@ class Scheduler:
             and not self.last_batch.is_empty()
         ):
             if self.being_chunked_req:
-                self.last_batch.filter_batch(
-                    being_chunked_req=self.being_chunked_req
-                )
+                self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
                 self.tree_cache.cache_unfinished_req(self.being_chunked_req)
                 # Inflight request keeps its rid but will get a new req_pool_idx.
                 self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
@@ -579,6 +650,10 @@ class Scheduler:
             return self.running_batch
 
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
+        # Check if the grammar is ready in the grammar queue
+        if self.grammar_queue:
+            self.move_ready_grammar_requests()
+
         # Handle the cases where prefill is not allowed
         if (
             self.batch_is_full or len(self.waiting_queue) == 0
@@ -594,7 +669,6 @@ class Scheduler:
         prefix_computed = self.policy.calc_priority(self.waiting_queue)
 
         # Prefill policy
-        num_mixed_running = running_bs if self.is_mixed_chunk else 0
         adder = PrefillAdder(
             self.tree_cache,
             self.running_batch,
@@ -602,15 +676,13 @@ class Scheduler:
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
             self.max_prefill_tokens,
             self.chunked_prefill_size,
-
+            running_bs if self.is_mixed_chunk else 0,
         )
 
         has_inflight = self.being_chunked_req is not None
         if has_inflight:
             self.being_chunked_req.init_next_round_input()
-            self.being_chunked_req = adder.add_inflight_req(
-                self.being_chunked_req
-            )
+            self.being_chunked_req = adder.add_inflight_req(self.being_chunked_req)
 
         if self.lora_paths:
             lora_set = (
@@ -661,44 +733,7 @@ class Scheduler:
 
         # Print stats
         if self.tp_rank == 0:
-
-                self.tree_cache_metrics["total"] += (
-                    adder.log_input_tokens + adder.log_hit_tokens
-                ) / 10**9
-                self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
-                tree_cache_hit_rate = (
-                    self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
-                )
-            else:
-                tree_cache_hit_rate = 0.0
-
-            num_used = self.max_total_num_tokens - (
-                self.token_to_kv_pool.available_size()
-                + self.tree_cache.evictable_size()
-            )
-
-            if num_mixed_running > 0:
-                logger.info(
-                    f"Prefill batch"
-                    f"(mixed #running-req: {num_mixed_running}). "
-                    f"#new-seq: {len(can_run_list)}, "
-                    f"#new-token: {adder.log_input_tokens}, "
-                    f"#cached-token: {adder.log_hit_tokens}, "
-                    f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
-                    f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                    f"#queue-req: {len(self.waiting_queue) + has_inflight}"
-                )
-            else:
-                logger.info(
-                    f"Prefill batch. "
-                    f"#new-seq: {len(can_run_list)}, "
-                    f"#new-token: {adder.log_input_tokens}, "
-                    f"#cached-token: {adder.log_hit_tokens}, "
-                    f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
-                    f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                    f"#running-req: {running_bs}, "
-                    f"#queue-req: {len(self.waiting_queue) + has_inflight}"
-                )
+            self.log_prefill_stats(adder, can_run_list, running_bs, has_inflight)
 
         # Create a new batch
         new_batch = ScheduleBatch.init_new(
@@ -753,7 +788,7 @@ class Scheduler:
         )
 
         # Check for jump-forward
-        if not self.
+        if not self.disable_jump_forward:
             jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
             self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():
@@ -768,8 +803,8 @@ class Scheduler:
         self.forward_ct += 1
 
         if self.is_generation:
+            model_worker_batch = batch.get_model_worker_batch()
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
-                model_worker_batch = batch.get_model_worker_batch()
                 logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                     model_worker_batch
                 )
@@ -897,9 +932,7 @@ class Scheduler:
             if req.is_retracted:
                 continue
 
-            if self.server_args.enable_overlap_schedule and (
-                req.finished()
-            ):
+            if self.server_args.enable_overlap_schedule and (req.finished()):
                 self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue
 
@@ -925,8 +958,11 @@ class Scheduler:
         self.token_to_kv_pool.free_group_end()
 
         self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
-        if
-            self.
+        if (
+            self.tp_rank == 0
+            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
+        ):
+            self.log_decode_stats()
 
     def add_logprob_return_values(
         self,
@@ -1104,6 +1140,30 @@ class Scheduler:
             )
         )
 
+    def move_ready_grammar_requests(self):
+        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
+        num_ready_reqs = 0
+        for req in self.grammar_queue:
+            try:
+                req.grammar = req.grammar.result(timeout=0.05)
+                num_ready_reqs += 1
+            except futures._base.TimeoutError:
+                break
+
+        if self.tp_size > 1:
+            # Sync across TP ranks to make sure they have the same number of ready requests
+            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
+            torch.distributed.all_reduce(
+                tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group
+            )
+            num_ready_reqs_max = tensor.item()
+            for i in range(num_ready_reqs, num_ready_reqs_max):
+                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
+            num_ready_reqs = num_ready_reqs_max
+
+        self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
+        self.grammar_queue = self.grammar_queue[num_ready_reqs:]
+
     def flush_cache(self):
         """Flush the memory pool and cache."""
         if len(self.waiting_queue) == 0 and (
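move_ready_grammar_requests polls each queued future with a 50 ms timeout and stops at the first one that is not ready; under tensor parallelism, a MAX all-reduce then makes every rank release the same number of requests, with slower ranks blocking on result() for the extras. The polling half of that pattern can be reproduced standalone (the distributed sync is omitted; names here are illustrative):

import time
from concurrent import futures

executor = futures.ThreadPoolExecutor(max_workers=2)


def slow_compile(delay):
    time.sleep(delay)
    return f"grammar(ready after {delay}s)"


# Pretend these are the futures stored on queued requests.
grammar_queue = [executor.submit(slow_compile, d) for d in (0.0, 0.0, 1.0)]

time.sleep(0.1)  # let the fast ones finish

num_ready = 0
for fut in grammar_queue:
    try:
        fut.result(timeout=0.05)   # same short poll as the scheduler
        num_ready += 1
    except futures.TimeoutError:   # public alias of futures._base.TimeoutError
        break                      # keep FIFO order: stop at the first unready one

ready, still_waiting = grammar_queue[:num_ready], grammar_queue[num_ready:]
print(len(ready), len(still_waiting))  # 2 1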
@@ -1111,9 +1171,8 @@ class Scheduler:
         ):
             self.tree_cache.reset()
             self.tree_cache_metrics = {"total": 0, "hit": 0}
-            if self.
-                self.
-                # TODO(dark): reset the bnf cache
+            if self.grammar_backend:
+                self.grammar_backend.reset()
             self.req_to_token_pool.clear()
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()