sglang 0.3.5__py3-none-any.whl → 0.3.5.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +309 -0
- sglang/bench_serving.py +148 -24
- sglang/srt/configs/model_config.py +5 -2
- sglang/srt/constrained/__init__.py +2 -66
- sglang/srt/constrained/base_grammar_backend.py +73 -0
- sglang/srt/constrained/outlines_backend.py +165 -0
- sglang/srt/constrained/outlines_jump_forward.py +182 -0
- sglang/srt/constrained/xgrammar_backend.py +150 -0
- sglang/srt/layers/attention/triton_ops/decode_attention.py +7 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +6 -0
- sglang/srt/layers/fused_moe/fused_moe.py +23 -7
- sglang/srt/layers/fused_moe/patch.py +4 -2
- sglang/srt/layers/quantization/base_config.py +4 -6
- sglang/srt/layers/vocab_parallel_embedding.py +216 -150
- sglang/srt/managers/detokenizer_manager.py +0 -14
- sglang/srt/managers/io_struct.py +5 -3
- sglang/srt/managers/schedule_batch.py +14 -20
- sglang/srt/managers/scheduler.py +159 -96
- sglang/srt/managers/tokenizer_manager.py +81 -17
- sglang/srt/metrics/collector.py +211 -0
- sglang/srt/metrics/func_timer.py +108 -0
- sglang/srt/mm_utils.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +2 -2
- sglang/srt/model_executor/forward_batch_info.py +7 -3
- sglang/srt/model_executor/model_runner.py +6 -2
- sglang/srt/models/gemma2_reward.py +69 -0
- sglang/srt/models/gpt2.py +31 -37
- sglang/srt/models/internlm2_reward.py +62 -0
- sglang/srt/models/llama.py +11 -6
- sglang/srt/models/llama_reward.py +5 -26
- sglang/srt/models/qwen2_vl.py +5 -7
- sglang/srt/openai_api/adapter.py +11 -4
- sglang/srt/openai_api/protocol.py +29 -26
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/sampling/sampling_params.py +2 -16
- sglang/srt/server.py +60 -17
- sglang/srt/server_args.py +66 -25
- sglang/srt/utils.py +120 -0
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_mgsm.py +2 -2
- sglang/test/test_utils.py +21 -7
- sglang/utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/METADATA +12 -8
- {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/RECORD +49 -45
- {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/WHEEL +1 -1
- sglang/srt/constrained/base_tool_cache.py +0 -65
- sglang/srt/constrained/bnf_cache.py +0 -61
- sglang/srt/constrained/fsm_cache.py +0 -95
- sglang/srt/constrained/grammar.py +0 -190
- sglang/srt/constrained/jump_forward.py +0 -203
- {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/LICENSE +0 -0
- {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -37,7 +37,7 @@ import torch
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained.
+from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -107,12 +107,14 @@ class FINISH_LENGTH(BaseFinishReason):
 
 
 class FINISH_ABORT(BaseFinishReason):
-    def __init__(self):
+    def __init__(self, message="Unknown error"):
         super().__init__(is_error=True)
+        self.message = message
 
     def to_json(self):
         return {
             "type": "abort",
+            "message": self.message,
         }
 
 
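For context, the new `message` argument flows straight into the abort payload that is returned to the client. A minimal illustration, assuming the class is importable from `sglang.srt.managers.schedule_batch` (the error string here is made up):

```python
# Illustrative only: FINISH_ABORT now carries a human-readable message in its JSON payload.
from sglang.srt.managers.schedule_batch import FINISH_ABORT

reason = FINISH_ABORT("input is longer than the model's context length")
print(reason.to_json())
# {'type': 'abort', 'message': "input is longer than the model's context length"}
```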
@@ -133,6 +135,7 @@ class ImageInputs:
     aspect_ratio_mask: Optional[List[torch.Tensor]] = None
     # QWen2-VL related
     image_grid_thws: List[Tuple[int, int, int]] = None
+    mrope_position_delta: Optional[torch.Tensor] = None
 
     @staticmethod
     def from_dict(obj, vocab_size):
@@ -211,7 +214,7 @@ class Req:
         # this does not include the jump forward tokens.
         self.completion_tokens_wo_jump_forward = 0
 
-        # For
+        # For multimodal inputs
         self.image_inputs: Optional[ImageInputs] = None
 
         # Prefix info
@@ -246,14 +249,11 @@ class Req:
         self.embedding = None
 
         # Constrained decoding
-        self.grammar: Optional[
+        self.grammar: Optional[BaseGrammarObject] = None
 
         # The number of cached tokens, that were already cached in the KV cache
         self.cached_tokens = 0
 
-        # For Qwen2-VL
-        self.mrope_position_delta = []  # use mutable object
-
     # whether request reached finished condition
     def finished(self) -> bool:
         return self.finished_reason is not None
@@ -359,8 +359,6 @@ class Req:
         return
 
     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-        assert self.grammar is not None and self.tokenizer is not None
-
         if self.origin_input_text is None:
             # Recovering text can only use unpadded ids
             self.origin_input_text = self.tokenizer.decode(
@@ -809,9 +807,10 @@ class ScheduleBatch:
 
         for i, req in enumerate(self.reqs):
             if req.grammar is not None:
-                jump_helper = req.grammar.
-                if jump_helper
-                    suffix_ids = jump_helper
+                jump_helper = req.grammar.try_jump_forward(req.tokenizer)
+                if jump_helper:
+                    suffix_ids, _ = jump_helper
+
                     # Current ids, for cache and revert
                     cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                     cur_output_ids = req.output_ids
@@ -827,6 +826,8 @@ class ScheduleBatch:
                         next_state,
                     ) = req.grammar.jump_forward_str_state(jump_helper)
 
+                    # Make the incrementally decoded text part of jump_forward_str
+                    # so that the UTF-8 will not corrupt
                     jump_forward_str = new_text + jump_forward_str
                     if not req.jump_forward_and_retokenize(
                         jump_forward_str, next_state
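The two-line comment added above points at a real pitfall: a multi-byte UTF-8 character can straddle the already-decoded text and the jump-forward suffix, so retokenizing the suffix on its own can corrupt it. A standalone illustration of the failure mode (plain Python, not sglang code):

```python
# Illustrative only: decoding a byte span that splits a multi-byte UTF-8 character corrupts it.
text = "日本"                                       # two 3-byte UTF-8 characters
data = text.encode("utf-8")

print(data[:4].decode("utf-8", errors="replace"))   # '日�' – the second character is cut in half
print(data.decode("utf-8"))                          # '日本' – decoding the full span is safe
```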
@@ -900,8 +901,7 @@ class ScheduleBatch:
         keep_indices = [
             i
             for i in range(len(self.reqs))
-            if not self.reqs[i].finished()
-            and self.reqs[i] is not being_chunked_req
+            if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
         ]
 
         if keep_indices is None or len(keep_indices) == 0:
@@ -984,8 +984,6 @@ class ScheduleBatch:
         global bid
         bid += 1
 
-        mrope_positions_delta = [req.mrope_position_delta for req in self.reqs]
-
         return ModelWorkerBatch(
             bid=bid,
             forward_mode=self.forward_mode,
@@ -1008,7 +1006,6 @@ class ScheduleBatch:
             encoder_out_cache_loc=self.encoder_out_cache_loc,
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
-            mrope_positions_delta=mrope_positions_delta,
         )
 
     def copy(self):
@@ -1075,9 +1072,6 @@ class ModelWorkerBatch:
     # Sampling info
     sampling_info: SamplingBatchInfo
 
-    # For Qwen2-VL
-    mrope_positions_delta: List[List[int]]
-
     def copy(self):
         return dataclasses.replace(self, sampling_info=self.sampling_info.copy())
 
sglang/srt/managers/scheduler.py
CHANGED
@@ -21,6 +21,7 @@ import threading
 import time
 import warnings
 from collections import deque
+from concurrent import futures
 from types import SimpleNamespace
 from typing import List, Optional
 
@@ -29,7 +30,6 @@ import zmq
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained.grammar import GrammarCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
@@ -62,6 +62,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
+from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     broadcast_pyobj,
@@ -99,11 +100,12 @@ class Scheduler:
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
         self.schedule_policy = server_args.schedule_policy
-        self.
+        self.disable_jump_forward = server_args.disable_jump_forward
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
         self.enable_overlap = server_args.enable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
+        self.enable_metrics = server_args.enable_metrics
 
         # Init inter-process communication
         context = zmq.Context(2)
@@ -112,6 +114,9 @@ class Scheduler:
         self.recv_from_tokenizer = get_zmq_socket(
             context, zmq.PULL, port_args.scheduler_input_ipc_name
         )
+        self.send_to_tokenizer = get_zmq_socket(
+            context, zmq.PUSH, port_args.tokenizer_ipc_name
+        )
 
         if server_args.skip_tokenizer_init:
             # Directly send to the tokenizer/api
@@ -125,6 +130,7 @@ class Scheduler:
             )
         else:
             self.recv_from_tokenizer = None
+            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
 
         # Init tokenizer
@@ -222,7 +228,7 @@ class Scheduler:
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
-        self.
+        self.last_decode_stats_tic = time.time()
         self.stream_interval = server_args.stream_interval
 
         # Init chunked prefill
@@ -232,21 +238,33 @@ class Scheduler:
             self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
         )
 
-        # Init the
-        self.
-
+        # Init the grammar backend for constrained generation
+        self.grammar_queue: List[Req] = []
         if not server_args.skip_tokenizer_init:
-
-
-
-
-
-
-
-
-
-
-
+            if server_args.grammar_backend == "outlines":
+                from sglang.srt.constrained.outlines_backend import (
+                    OutlinesGrammarBackend,
+                )
+
+                self.grammar_backend = OutlinesGrammarBackend(
+                    self.tokenizer,
+                    whitespace_pattern=server_args.constrained_json_whitespace_pattern,
+                    allow_jump_forward=not server_args.disable_jump_forward,
+                )
+            elif server_args.grammar_backend == "xgrammar":
+                from sglang.srt.constrained.xgrammar_backend import (
+                    XGrammarGrammarBackend,
+                )
+
+                self.grammar_backend = XGrammarGrammarBackend(
+                    self.tokenizer, vocab_size=self.model_config.vocab_size
+                )
+            else:
+                raise ValueError(
+                    f"Invalid grammar backend: {server_args.grammar_backend}"
+                )
+        else:
+            self.grammar_backend = None
 
         # Init new token estimation
         assert (
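The `get_cached_value` / `get_future_value` pair used later in this diff suggests a backend that compiles grammars asynchronously and caches the results keyed by `("json" | "regex", spec)`. Below is a minimal sketch of that pattern, not the actual `base_grammar_backend.py`; the two method names come from the diff, everything else (class name, executor size, stand-in compile step) is illustrative:

```python
# Illustrative sketch of a future-based grammar cache, not sglang's implementation.
from concurrent.futures import Future, ThreadPoolExecutor
from threading import Lock
from typing import Dict, Optional, Tuple


class ToyGrammarBackend:
    def __init__(self):
        self.executor = ThreadPoolExecutor(max_workers=2)
        self.cache: Dict[Tuple[str, str], object] = {}
        self.lock = Lock()

    def _compile(self, key):
        # Stand-in for expensive FSM / grammar compilation.
        kind, spec = key
        grammar = f"compiled<{kind}:{spec}>"
        with self.lock:
            self.cache[key] = grammar
        return grammar

    def get_cached_value(self, key) -> Optional[object]:
        # Return a ready grammar immediately, or None if it was never compiled.
        with self.lock:
            return self.cache.get(key)

    def get_future_value(self, key) -> Future:
        # Kick off compilation in the background; the scheduler polls the Future later.
        return self.executor.submit(self._compile, key)
```

Under this reading, a request whose grammar misses the cache holds the `Future` and waits in `grammar_queue` until `move_ready_grammar_requests` (added near the end of this diff) resolves it.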
@@ -292,6 +310,16 @@ class Scheduler:
             with_stack=True,
         )
 
+        # Init metrics stats
+        self.stats = SchedulerStats()
+        if self.enable_metrics:
+            self.metrics_collector = SchedulerMetricsCollector(
+                labels={
+                    "model_name": self.server_args.served_model_name,
+                    # TODO: Add lora name/path in the future,
+                },
+            )
+
     def watchdog_thread(self):
         self.watchdog_last_forward_ct = 0
         self.watchdog_last_time = time.time()
|
@@ -397,7 +425,7 @@ class Scheduler:
|
|
397
425
|
self.abort_request(recv_req)
|
398
426
|
elif isinstance(recv_req, UpdateWeightReqInput):
|
399
427
|
success, message = self.update_weights(recv_req)
|
400
|
-
self.
|
428
|
+
self.send_to_tokenizer.send_pyobj(
|
401
429
|
UpdateWeightReqOutput(success, message)
|
402
430
|
)
|
403
431
|
elif isinstance(recv_req, ProfileReq):
|
@@ -406,7 +434,7 @@ class Scheduler:
|
|
406
434
|
else:
|
407
435
|
self.stop_profile()
|
408
436
|
elif isinstance(recv_req, GetMemPoolSizeReq):
|
409
|
-
self.
|
437
|
+
self.send_to_tokenizer.send_pyobj(
|
410
438
|
GetMemPoolSizeReqOutput(self.max_total_num_tokens)
|
411
439
|
)
|
412
440
|
else:
|
@@ -443,22 +471,6 @@ class Scheduler:
|
|
443
471
|
# By default, only return the logprobs for output tokens
|
444
472
|
req.logprob_start_len = len(recv_req.input_ids) - 1
|
445
473
|
|
446
|
-
# Init regex FSM or BNF
|
447
|
-
if (
|
448
|
-
req.sampling_params.json_schema is not None
|
449
|
-
or req.sampling_params.regex is not None
|
450
|
-
):
|
451
|
-
assert self.grammar_cache is not None
|
452
|
-
if req.sampling_params.json_schema is not None:
|
453
|
-
req.grammar = self.grammar_cache.query(
|
454
|
-
("json", req.sampling_params.json_schema),
|
455
|
-
self.model_config.vocab_size,
|
456
|
-
)
|
457
|
-
elif req.sampling_params.regex is not None:
|
458
|
-
req.grammar = self.grammar_cache.query(
|
459
|
-
("regex", req.sampling_params.regex), self.model_config.vocab_size
|
460
|
-
)
|
461
|
-
|
462
474
|
# Truncate prompts that are too long
|
463
475
|
if len(req.origin_input_ids) > self.max_req_input_len:
|
464
476
|
logger.warning(
|
@@ -476,7 +488,27 @@ class Scheduler:
|
|
476
488
|
self.max_req_len - len(req.origin_input_ids) - 1,
|
477
489
|
)
|
478
490
|
|
479
|
-
|
491
|
+
# Init grammar cache for this request
|
492
|
+
add_to_grammar_queue = False
|
493
|
+
if (
|
494
|
+
req.sampling_params.json_schema is not None
|
495
|
+
or req.sampling_params.regex is not None
|
496
|
+
):
|
497
|
+
assert self.grammar_backend is not None
|
498
|
+
if req.sampling_params.json_schema is not None:
|
499
|
+
key = ("json", req.sampling_params.json_schema)
|
500
|
+
elif req.sampling_params.regex is not None:
|
501
|
+
key = ("regex", req.sampling_params.regex)
|
502
|
+
|
503
|
+
req.grammar = self.grammar_backend.get_cached_value(key)
|
504
|
+
if not req.grammar:
|
505
|
+
req.grammar = self.grammar_backend.get_future_value(key)
|
506
|
+
add_to_grammar_queue = True
|
507
|
+
|
508
|
+
if add_to_grammar_queue:
|
509
|
+
self.grammar_queue.append(req)
|
510
|
+
else:
|
511
|
+
self.waiting_queue.append(req)
|
480
512
|
|
481
513
|
def handle_embedding_request(
|
482
514
|
self,
|
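To see the new queueing path in action: any generate request that carries a `json_schema` or `regex` sampling parameter goes through the branch above. A hedged client-side sketch against a locally running server; the port, prompt, and schema are placeholders, and the exact wire format of the native `/generate` endpoint (in particular whether `json_schema` is passed as a JSON-encoded string) is an assumption:

```python
# Illustrative request that exercises the grammar queue; values are placeholders.
import json

import requests

schema = json.dumps(
    {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        "required": ["name", "age"],
    }
)

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Describe Alan Turing as a JSON object.",
        "sampling_params": {
            "max_new_tokens": 64,
            "temperature": 0,
            "json_schema": schema,  # routed through grammar_backend / grammar_queue
        },
    },
)
print(response.json()["text"])
```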
@@ -500,23 +532,68 @@ class Scheduler:
 
         self.waiting_queue.append(req)
 
-    def
+    def log_prefill_stats(self, adder, can_run_list, running_bs, has_inflight):
+        if isinstance(self.tree_cache, RadixCache):
+            self.tree_cache_metrics["total"] += (
+                adder.log_input_tokens + adder.log_hit_tokens
+            ) / 10**9
+            self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
+            tree_cache_hit_rate = (
+                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
+            )
+        else:
+            tree_cache_hit_rate = 0.0
+
+        num_used = self.max_total_num_tokens - (
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+        )
+
+        logger.info(
+            f"Prefill batch. "
+            f"#new-seq: {len(can_run_list)}, "
+            f"#new-token: {adder.log_input_tokens}, "
+            f"#cached-token: {adder.log_hit_tokens}, "
+            f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+            f"#running-req: {running_bs}, "
+            f"#queue-req: {len(self.waiting_queue) + has_inflight}"
+        )
+
+        if self.enable_metrics:
+            self.stats.num_running_reqs = running_bs
+            self.stats.num_used_tokens = num_used
+            self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
+            self.stats.num_queue_reqs = len(self.waiting_queue) + has_inflight
+            self.stats.cache_hit_rate = tree_cache_hit_rate
+            self.metrics_collector.log_stats(self.stats)
+
+    def log_decode_stats(self):
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
-
+        gen_throughput = self.num_generated_tokens / (
+            time.time() - self.last_decode_stats_tic
+        )
         self.num_generated_tokens = 0
-        self.
+        self.last_decode_stats_tic = time.time()
         num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
         logger.info(
             f"Decode batch. "
             f"#running-req: {num_running_reqs}, "
             f"#token: {num_used}, "
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-            f"gen throughput (token/s): {
+            f"gen throughput (token/s): {gen_throughput:.2f}, "
             f"#queue-req: {len(self.waiting_queue)}"
         )
 
+        if self.enable_metrics:
+            self.stats.num_running_reqs = num_running_reqs
+            self.stats.num_used_tokens = num_used
+            self.stats.token_usage = num_used / self.max_total_num_tokens
+            self.stats.gen_throughput = gen_throughput
+            self.stats.num_queue_reqs = len(self.waiting_queue)
+            self.metrics_collector.log_stats(self.stats)
+
     def check_memory(self):
         available_size = (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
@@ -546,9 +623,7 @@ class Scheduler:
             and not self.last_batch.is_empty()
         ):
             if self.being_chunked_req:
-                self.last_batch.filter_batch(
-                    being_chunked_req=self.being_chunked_req
-                )
+                self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
                 self.tree_cache.cache_unfinished_req(self.being_chunked_req)
                 # Inflight request keeps its rid but will get a new req_pool_idx.
                 self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
@@ -579,6 +654,10 @@ class Scheduler:
         return self.running_batch
 
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
+        # Check if the grammar is ready in the grammar queue
+        if self.grammar_queue:
+            self.move_ready_grammar_requests()
+
         # Handle the cases where prefill is not allowed
         if (
             self.batch_is_full or len(self.waiting_queue) == 0
@@ -594,7 +673,6 @@ class Scheduler:
         prefix_computed = self.policy.calc_priority(self.waiting_queue)
 
         # Prefill policy
-        num_mixed_running = running_bs if self.is_mixed_chunk else 0
         adder = PrefillAdder(
             self.tree_cache,
             self.running_batch,
@@ -602,15 +680,13 @@ class Scheduler:
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
             self.max_prefill_tokens,
             self.chunked_prefill_size,
-
+            running_bs if self.is_mixed_chunk else 0,
         )
 
         has_inflight = self.being_chunked_req is not None
         if has_inflight:
             self.being_chunked_req.init_next_round_input()
-            self.being_chunked_req = adder.add_inflight_req(
-                self.being_chunked_req
-            )
+            self.being_chunked_req = adder.add_inflight_req(self.being_chunked_req)
 
         if self.lora_paths:
             lora_set = (
@@ -661,44 +737,7 @@ class Scheduler:
 
         # Print stats
         if self.tp_rank == 0:
-
-            self.tree_cache_metrics["total"] += (
-                adder.log_input_tokens + adder.log_hit_tokens
-            ) / 10**9
-            self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
-            tree_cache_hit_rate = (
-                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
-            )
-            else:
-                tree_cache_hit_rate = 0.0
-
-            num_used = self.max_total_num_tokens - (
-                self.token_to_kv_pool.available_size()
-                + self.tree_cache.evictable_size()
-            )
-
-            if num_mixed_running > 0:
-                logger.info(
-                    f"Prefill batch"
-                    f"(mixed #running-req: {num_mixed_running}). "
-                    f"#new-seq: {len(can_run_list)}, "
-                    f"#new-token: {adder.log_input_tokens}, "
-                    f"#cached-token: {adder.log_hit_tokens}, "
-                    f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
-                    f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                    f"#queue-req: {len(self.waiting_queue) + has_inflight}"
-                )
-            else:
-                logger.info(
-                    f"Prefill batch. "
-                    f"#new-seq: {len(can_run_list)}, "
-                    f"#new-token: {adder.log_input_tokens}, "
-                    f"#cached-token: {adder.log_hit_tokens}, "
-                    f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
-                    f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                    f"#running-req: {running_bs}, "
-                    f"#queue-req: {len(self.waiting_queue) + has_inflight}"
-                )
+            self.log_prefill_stats(adder, can_run_list, running_bs, has_inflight)
 
         # Create a new batch
         new_batch = ScheduleBatch.init_new(
@@ -753,7 +792,7 @@ class Scheduler:
         )
 
         # Check for jump-forward
-        if not self.
+        if not self.disable_jump_forward:
             jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
             self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():
@@ -768,8 +807,8 @@ class Scheduler:
         self.forward_ct += 1
 
         if self.is_generation:
+            model_worker_batch = batch.get_model_worker_batch()
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
-                model_worker_batch = batch.get_model_worker_batch()
                 logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                     model_worker_batch
                 )
@@ -897,9 +936,7 @@ class Scheduler:
             if req.is_retracted:
                 continue
 
-            if self.server_args.enable_overlap_schedule and (
-                req.finished()
-            ):
+            if self.server_args.enable_overlap_schedule and (req.finished()):
                 self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue
 
@@ -925,8 +962,11 @@ class Scheduler:
             self.token_to_kv_pool.free_group_end()
 
         self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
-        if
-            self.
+        if (
+            self.tp_rank == 0
+            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
+        ):
+            self.log_decode_stats()
 
     def add_logprob_return_values(
         self,
@@ -1104,6 +1144,30 @@ class Scheduler:
             )
         )
 
+    def move_ready_grammar_requests(self):
+        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
+        num_ready_reqs = 0
+        for req in self.grammar_queue:
+            try:
+                req.grammar = req.grammar.result(timeout=0.05)
+                num_ready_reqs += 1
+            except futures._base.TimeoutError:
+                break
+
+        if self.tp_size > 1:
+            # Sync across TP ranks to make sure they have the same number of ready requests
+            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
+            torch.distributed.all_reduce(
+                tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group
+            )
+            num_ready_reqs_max = tensor.item()
+            for i in range(num_ready_reqs, num_ready_reqs_max):
+                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
+            num_ready_reqs = num_ready_reqs_max
+
+        self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
+        self.grammar_queue = self.grammar_queue[num_ready_reqs:]
+
     def flush_cache(self):
         """Flush the memory pool and cache."""
         if len(self.waiting_queue) == 0 and (
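The polling loop added above relies on `Future.result(timeout=...)` raising a timeout error for grammars that are still compiling, and it stops at the first unfinished future so request order is preserved; the TP ranks then agree on the maximum ready count and block on the remaining futures so every rank moves the same requests. A standalone illustration of the single-rank polling pattern (toy tasks, not sglang code):

```python
# Illustrative only: poll futures with a short timeout and stop at the first unfinished one.
import time
from concurrent import futures

with futures.ThreadPoolExecutor() as pool:
    queue = [pool.submit(lambda: "fast"), pool.submit(lambda: time.sleep(1) or "slow")]

    ready = 0
    for fut in queue:
        try:
            fut.result(timeout=0.05)   # same call the scheduler makes per grammar future
            ready += 1
        except futures.TimeoutError:   # futures._base.TimeoutError is the same class
            break

    print(f"{ready} of {len(queue)} grammars ready")  # the rest stay queued for the next tick
```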
@@ -1111,9 +1175,8 @@ class Scheduler:
         ):
             self.tree_cache.reset()
             self.tree_cache_metrics = {"total": 0, "hit": 0}
-            if self.
-                self.
-                # TODO(dark): reset the bnf cache
+            if self.grammar_backend:
+                self.grammar_backend.reset()
             self.req_to_token_pool.clear()
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()