sglang 0.3.4.post1__py3-none-any.whl → 0.3.5__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
- sglang/api.py +1 -1
- sglang/bench_latency.py +3 -3
- sglang/bench_server_latency.py +2 -3
- sglang/bench_serving.py +92 -0
- sglang/global_config.py +9 -3
- sglang/lang/chat_template.py +50 -25
- sglang/lang/interpreter.py +9 -1
- sglang/lang/ir.py +11 -2
- sglang/launch_server.py +1 -1
- sglang/srt/configs/model_config.py +76 -15
- sglang/srt/constrained/__init__.py +18 -0
- sglang/srt/constrained/bnf_cache.py +61 -0
- sglang/srt/constrained/fsm_cache.py +10 -3
- sglang/srt/constrained/grammar.py +190 -0
- sglang/srt/hf_transformers_utils.py +20 -5
- sglang/srt/layers/attention/flashinfer_backend.py +5 -5
- sglang/srt/layers/attention/triton_ops/decode_attention.py +110 -30
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
- sglang/srt/layers/fused_moe/fused_moe.py +4 -3
- sglang/srt/layers/fused_moe/layer.py +28 -0
- sglang/srt/layers/logits_processor.py +5 -5
- sglang/srt/layers/quantization/base_config.py +16 -1
- sglang/srt/layers/rotary_embedding.py +15 -48
- sglang/srt/layers/sampler.py +51 -39
- sglang/srt/layers/vocab_parallel_embedding.py +486 -0
- sglang/srt/managers/data_parallel_controller.py +8 -7
- sglang/srt/managers/detokenizer_manager.py +11 -9
- sglang/srt/managers/image_processor.py +4 -3
- sglang/srt/managers/io_struct.py +80 -78
- sglang/srt/managers/schedule_batch.py +46 -52
- sglang/srt/managers/schedule_policy.py +24 -13
- sglang/srt/managers/scheduler.py +145 -82
- sglang/srt/managers/tokenizer_manager.py +236 -334
- sglang/srt/managers/tp_worker.py +5 -5
- sglang/srt/managers/tp_worker_overlap_thread.py +58 -21
- sglang/srt/mem_cache/flush_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +10 -3
- sglang/srt/model_executor/cuda_graph_runner.py +34 -23
- sglang/srt/model_executor/forward_batch_info.py +6 -9
- sglang/srt/model_executor/model_runner.py +10 -19
- sglang/srt/models/baichuan.py +4 -4
- sglang/srt/models/chatglm.py +4 -4
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +5 -5
- sglang/srt/models/deepseek.py +4 -4
- sglang/srt/models/deepseek_v2.py +4 -4
- sglang/srt/models/exaone.py +4 -4
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -1
- sglang/srt/models/gpt2.py +287 -0
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +4 -4
- sglang/srt/models/internlm2.py +4 -4
- sglang/srt/models/llama.py +15 -7
- sglang/srt/models/llama_embedding.py +2 -10
- sglang/srt/models/llama_reward.py +5 -0
- sglang/srt/models/minicpm.py +4 -4
- sglang/srt/models/minicpm3.py +4 -4
- sglang/srt/models/mixtral.py +7 -5
- sglang/srt/models/mixtral_quant.py +4 -4
- sglang/srt/models/mllama.py +5 -5
- sglang/srt/models/olmo.py +4 -4
- sglang/srt/models/olmoe.py +4 -4
- sglang/srt/models/qwen.py +4 -4
- sglang/srt/models/qwen2.py +4 -4
- sglang/srt/models/qwen2_moe.py +4 -4
- sglang/srt/models/qwen2_vl.py +4 -8
- sglang/srt/models/stablelm.py +4 -4
- sglang/srt/models/torch_native_llama.py +4 -4
- sglang/srt/models/xverse.py +4 -4
- sglang/srt/models/xverse_moe.py +4 -4
- sglang/srt/openai_api/adapter.py +52 -66
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -13
- sglang/srt/sampling/sampling_params.py +5 -7
- sglang/srt/server.py +41 -33
- sglang/srt/server_args.py +34 -5
- sglang/srt/utils.py +40 -56
- sglang/test/run_eval.py +2 -0
- sglang/test/runners.py +2 -1
- sglang/test/srt/sampling/penaltylib/utils.py +1 -0
- sglang/test/test_utils.py +151 -6
- sglang/utils.py +62 -1
- sglang/version.py +1 -1
- sglang-0.3.5.dist-info/METADATA +344 -0
- sglang-0.3.5.dist-info/RECORD +152 -0
- {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/WHEEL +1 -1
- sglang-0.3.4.post1.dist-info/METADATA +0 -900
- sglang-0.3.4.post1.dist-info/RECORD +0 -148
- {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/LICENSE +0 -0
- {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/top_level.txt +0 -0
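The per-file +/- counts above come from the registry's diff service. As a rough, hedged sketch, a similar file-level comparison can be reproduced locally with only the standard library; the wheel filenames below are assumptions based on the versions in this diff, and the counts may not match the registry's exactly.

# Hypothetical local reproduction; assumes both wheels were downloaded beforehand.
import difflib
import zipfile

OLD_WHEEL = "sglang-0.3.4.post1-py3-none-any.whl"  # assumed local path
NEW_WHEEL = "sglang-0.3.5-py3-none-any.whl"        # assumed local path

with zipfile.ZipFile(OLD_WHEEL) as old, zipfile.ZipFile(NEW_WHEEL) as new:
    old_names, new_names = set(old.namelist()), set(new.namelist())
    print("added files:", sorted(new_names - old_names))
    print("removed files:", sorted(old_names - new_names))

    # Per-file +N -M counts for Python files present in both wheels.
    for name in sorted(old_names & new_names):
        if not name.endswith(".py"):
            continue
        a = old.read(name).decode("utf-8").splitlines()
        b = new.read(name).decode("utf-8").splitlines()
        diff = list(difflib.unified_diff(a, b, lineterm=""))
        plus = sum(1 for line in diff if line.startswith("+") and not line.startswith("+++"))
        minus = sum(1 for line in diff if line.startswith("-") and not line.startswith("---"))
        if plus or minus:
            print(f"{name} +{plus} -{minus}")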
sglang/srt/managers/scheduler.py
CHANGED
@@ -15,22 +15,21 @@ limitations under the License.
 
 """A scheduler that manages a tensor parallel GPU worker."""
 
-import json
 import logging
 import os
+import threading
 import time
 import warnings
 from collections import deque
 from types import SimpleNamespace
-from typing import List, Optional
+from typing import List, Optional
 
 import torch
 import zmq
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained.
-from sglang.srt.constrained.jump_forward import JumpForwardCache
+from sglang.srt.constrained.grammar import GrammarCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
@@ -38,10 +37,11 @@ from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchTokenIDOut,
     FlushCacheReq,
+    GetMemPoolSizeReq,
+    GetMemPoolSizeReqOutput,
     ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    TokenizedRewardReqInput,
     UpdateWeightReqInput,
     UpdateWeightReqOutput,
 )
@@ -66,10 +66,8 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
-
-    is_multimodal_model,
+    get_zmq_socket,
     kill_parent_process,
-    pytorch_profile,
     set_random_seed,
     suppress_other_loggers,
 )
@@ -77,6 +75,7 @@ from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
+
 # Crash on warning if we are running CI tests
 crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
 
@@ -104,16 +103,26 @@ class Scheduler:
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
         self.enable_overlap = server_args.enable_overlap_schedule
+        self.skip_tokenizer_init = server_args.skip_tokenizer_init
 
         # Init inter-process communication
         context = zmq.Context(2)
 
         if self.tp_rank == 0:
-            self.recv_from_tokenizer =
-
+            self.recv_from_tokenizer = get_zmq_socket(
+                context, zmq.PULL, port_args.scheduler_input_ipc_name
+            )
 
-
-
+            if server_args.skip_tokenizer_init:
+                # Directly send to the tokenizer/api
+                self.send_to_detokenizer = get_zmq_socket(
+                    context, zmq.PUSH, port_args.tokenizer_ipc_name
+                )
+            else:
+                # Send to the detokenizer
+                self.send_to_detokenizer = get_zmq_socket(
+                    context, zmq.PUSH, port_args.detokenizer_ipc_name
+                )
         else:
             self.recv_from_tokenizer = None
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
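The hunk above routes the scheduler's IPC through the get_zmq_socket helper and, when skip_tokenizer_init is set, points send_to_detokenizer straight back at the tokenizer/API endpoint. A minimal sketch of the underlying PULL/PUSH pattern in plain pyzmq follows; the ipc endpoint paths and the bind/connect roles are assumptions, not the behavior of sglang's actual helper.

import zmq

context = zmq.Context(2)

# The scheduler pulls tokenized requests from the tokenizer process ...
recv_from_tokenizer = context.socket(zmq.PULL)
recv_from_tokenizer.bind("ipc:///tmp/scheduler_input")  # assumed endpoint name

# ... and pushes results either to the detokenizer or straight back to the API.
send_to_detokenizer = context.socket(zmq.PUSH)
send_to_detokenizer.connect("ipc:///tmp/detokenizer")  # assumed endpoint name

# Python objects travel as pickles over the sockets:
#   send_to_detokenizer.send_pyobj(batch_out)
#   recv_req = recv_from_tokenizer.recv_pyobj()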
@@ -121,15 +130,17 @@ class Scheduler:
         # Init tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
-            server_args.trust_remote_code,
+            trust_remote_code=server_args.trust_remote_code,
             context_length=server_args.context_length,
-            model_override_args=
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
         )
+        self.is_generation = self.model_config.is_generation
 
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
-            if
+            if self.model_config.is_multimodal:
                 self.processor = get_processor(
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,
@@ -142,9 +153,6 @@ class Scheduler:
                     tokenizer_mode=server_args.tokenizer_mode,
                     trust_remote_code=server_args.trust_remote_code,
                 )
-            self.is_generation = is_generation_model(
-                self.model_config.hf_config.architectures, self.server_args.is_embedding
-            )
 
         # Launch a tensor parallel worker
         if self.enable_overlap:
@@ -211,44 +219,62 @@ class Scheduler:
         self.waiting_queue: List[Req] = []
         self.running_batch: Optional[ScheduleBatch] = None
         self.cur_batch: Optional[ScheduleBatch] = None
-        self.
-        self.
+        self.forward_ct = 0
+        self.forward_ct_decode = 0
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()
+        self.stream_interval = server_args.stream_interval
 
         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
-        self.
+        self.being_chunked_req = None
         self.is_mixed_chunk = (
             self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
         )
 
         # Init the FSM cache for constrained generation
+        self.grammar_cache = None
+
         if not server_args.skip_tokenizer_init:
-            self.
+            self.grammar_cache = GrammarCache(
                 server_args.tokenizer_path,
                 {
                     "tokenizer_mode": server_args.tokenizer_mode,
                     "trust_remote_code": server_args.trust_remote_code,
                 },
                 skip_tokenizer_init=server_args.skip_tokenizer_init,
-
+                whitespace_patterns=server_args.constrained_json_whitespace_pattern,
+                backend=server_args.grammar_backend,
+                allow_jump=not server_args.disable_regex_jump_forward,
             )
-            self.jump_forward_cache = JumpForwardCache()
 
         # Init new token estimation
         assert (
             server_args.schedule_conservativeness >= 0
         ), "Invalid schedule_conservativeness"
-
-
+
+        self.init_new_token_ratio = min(
+            global_config.default_init_new_token_ratio
             * server_args.schedule_conservativeness,
             1.0,
         )
-        self.
-
+        self.min_new_token_ratio = min(
+            self.init_new_token_ratio
+            * global_config.default_min_new_token_ratio_factor,
+            1.0,
+        )
+        self.new_token_ratio_decay = (
+            self.init_new_token_ratio - self.min_new_token_ratio
+        ) / global_config.default_new_token_ratio_decay_steps
+        self.new_token_ratio = self.init_new_token_ratio
+
         self.batch_is_full = False
 
+        # Init watchdog thread
+        self.watchdog_timeout = server_args.watchdog_timeout
+        t = threading.Thread(target=self.watchdog_thread, daemon=True)
+        t.start()
+
         # Init profiler
         if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
             self.profiler = None
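The new-token-ratio block above derives three quantities from global_config knobs: a starting ratio scaled by schedule_conservativeness, a floor obtained by multiplying it with a minimum factor, and a linear decay spread over a fixed number of steps. A small sketch of the arithmetic with illustrative numbers (the real defaults live in sglang.global_config and may differ):

# Illustrative values only; not the actual sglang.global_config defaults.
default_init_new_token_ratio = 0.7
default_min_new_token_ratio_factor = 0.14
default_new_token_ratio_decay_steps = 600
schedule_conservativeness = 1.0

init_ratio = min(default_init_new_token_ratio * schedule_conservativeness, 1.0)
min_ratio = min(init_ratio * default_min_new_token_ratio_factor, 1.0)
decay = (init_ratio - min_ratio) / default_new_token_ratio_decay_steps

# The scheduler starts at init_ratio and (per schedule code not shown in this
# hunk) steps the ratio down toward min_ratio over time, resetting it back to
# init_ratio when there is nothing left to run.
ratio = init_ratio
for _ in range(100):
    ratio = max(ratio - decay, min_ratio)
print(round(init_ratio, 3), round(min_ratio, 3), round(ratio, 4))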
@@ -266,6 +292,23 @@ class Scheduler:
                 with_stack=True,
             )
 
+    def watchdog_thread(self):
+        self.watchdog_last_forward_ct = 0
+        self.watchdog_last_time = time.time()
+
+        while True:
+            if self.cur_batch is not None:
+                if self.watchdog_last_forward_ct == self.forward_ct:
+                    if time.time() > self.watchdog_last_time + self.watchdog_timeout:
+                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
+                        break
+                else:
+                    self.watchdog_last_forward_ct = self.forward_ct
+                    self.watchdog_last_time = time.time()
+            time.sleep(self.watchdog_timeout / 2)
+
+        kill_parent_process()
+
     @torch.inference_mode()
     def event_loop_normal(self):
         """A normal blocking scheduler loop."""
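watchdog_thread above is a liveness check: a daemon thread samples forward_ct and, if no forward pass completes within watchdog_timeout while a batch is in flight, logs an error and tears the process down. A self-contained sketch of the same pattern outside sglang (killing the current process instead of the parent) follows.

import os
import signal
import threading
import time

forward_ct = 0            # advanced by the main loop after every forward pass
watchdog_timeout = 5.0    # seconds without progress before aborting


def watchdog():
    last_ct, last_time = 0, time.time()
    while True:
        if last_ct == forward_ct:
            if time.time() > last_time + watchdog_timeout:
                print(f"Watchdog timeout ({watchdog_timeout=})")
                break
        else:
            last_ct, last_time = forward_ct, time.time()
        time.sleep(watchdog_timeout / 2)
    os.kill(os.getpid(), signal.SIGKILL)


threading.Thread(target=watchdog, daemon=True).start()

for step in range(3):     # stand-in for the scheduler's event loop
    time.sleep(0.1)       # pretend to run a forward pass
    forward_ct += 1

time.sleep(60)            # simulate a hang; the watchdog fires and kills the process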
@@ -276,6 +319,7 @@ class Scheduler:
             self.process_input_requests(recv_reqs)
 
             batch = self.get_next_batch_to_run()
+            self.cur_batch = batch
 
             if batch:
                 result = self.run_batch(batch)
@@ -293,7 +337,7 @@ class Scheduler:
                 self.process_batch_result(batch, result)
             else:
                 self.check_memory()
-                self.new_token_ratio =
+                self.new_token_ratio = self.init_new_token_ratio
 
             self.last_batch = batch
 
@@ -320,7 +364,7 @@ class Scheduler:
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
                 self.check_memory()
-                self.new_token_ratio =
+                self.new_token_ratio = self.init_new_token_ratio
 
             self.last_batch = batch
 
@@ -345,9 +389,7 @@ class Scheduler:
         for recv_req in recv_reqs:
             if isinstance(recv_req, TokenizedGenerateReqInput):
                 self.handle_generate_request(recv_req)
-            elif isinstance(
-                recv_req, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput)
-            ):
+            elif isinstance(recv_req, TokenizedEmbeddingReqInput):
                 self.handle_embedding_request(recv_req)
             elif isinstance(recv_req, FlushCacheReq):
                 self.flush_cache()
@@ -363,6 +405,10 @@ class Scheduler:
                     self.start_profile()
                 else:
                     self.stop_profile()
+            elif isinstance(recv_req, GetMemPoolSizeReq):
+                self.send_to_detokenizer.send_pyobj(
+                    GetMemPoolSizeReqOutput(self.max_total_num_tokens)
+                )
             else:
                 raise ValueError(f"Invalid request: {recv_req}")
 
@@ -397,26 +443,24 @@ class Scheduler:
         # By default, only return the logprobs for output tokens
         req.logprob_start_len = len(recv_req.input_ids) - 1
 
-        # Init regex FSM
+        # Init regex FSM or BNF
         if (
             req.sampling_params.json_schema is not None
             or req.sampling_params.regex is not None
         ):
+            assert self.grammar_cache is not None
             if req.sampling_params.json_schema is not None:
-                req.
-                    ("json", req.sampling_params.json_schema)
+                req.grammar = self.grammar_cache.query(
+                    ("json", req.sampling_params.json_schema),
+                    self.model_config.vocab_size,
                 )
             elif req.sampling_params.regex is not None:
-                req.
-                    ("regex", req.sampling_params.regex)
-                )
-                if not self.disable_regex_jump_forward:
-                    req.jump_forward_map = self.jump_forward_cache.query(
-                        computed_regex_string
+                req.grammar = self.grammar_cache.query(
+                    ("regex", req.sampling_params.regex), self.model_config.vocab_size
                 )
 
         # Truncate prompts that are too long
-        if len(req.origin_input_ids)
+        if len(req.origin_input_ids) > self.max_req_input_len:
             logger.warning(
                 "Request length is longer than the KV cache pool size or "
                 "the max context length. Truncated!!!"
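Both the json_schema and regex branches above now funnel into a single grammar cache keyed by a ("json", schema) or ("regex", pattern) tuple plus the vocabulary size. A tiny sketch of that keying scheme, independent of the real GrammarCache internals:

from functools import lru_cache


@lru_cache(maxsize=None)
def compile_grammar(key: tuple, vocab_size: int) -> str:
    # Stand-in for compiling an FSM/BNF matcher over the tokenizer vocabulary.
    kind, spec = key
    if kind not in ("json", "regex"):
        raise ValueError(f"unknown grammar kind: {kind}")
    return f"<{kind} grammar for {spec!r}, vocab={vocab_size}>"


# Requests sharing the same schema or regex reuse the compiled object.
g1 = compile_grammar(("regex", r"\d+"), 32000)
g2 = compile_grammar(("regex", r"\d+"), 32000)
assert g1 is g2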
@@ -436,7 +480,7 @@ class Scheduler:
 
     def handle_embedding_request(
         self,
-        recv_req:
+        recv_req: TokenizedEmbeddingReqInput,
     ):
         req = Req(
             recv_req.rid,
@@ -501,13 +545,13 @@ class Scheduler:
             and not self.last_batch.forward_mode.is_decode()
             and not self.last_batch.is_empty()
         ):
-            if self.
+            if self.being_chunked_req:
                 self.last_batch.filter_batch(
-
+                    being_chunked_req=self.being_chunked_req
                 )
-                self.tree_cache.cache_unfinished_req(self.
+                self.tree_cache.cache_unfinished_req(self.being_chunked_req)
                 # Inflight request keeps its rid but will get a new req_pool_idx.
-                self.req_to_token_pool.free(self.
+                self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
                 self.batch_is_full = False
             if not self.last_batch.is_empty():
                 if self.running_batch is None:
@@ -538,7 +582,7 @@ class Scheduler:
         # Handle the cases where prefill is not allowed
         if (
             self.batch_is_full or len(self.waiting_queue) == 0
-        ) and self.
+        ) and self.being_chunked_req is None:
             return None
 
         running_bs = len(self.running_batch.reqs) if self.running_batch else 0
@@ -561,13 +605,11 @@ class Scheduler:
             num_mixed_running,
         )
 
-        has_inflight = self.
+        has_inflight = self.being_chunked_req is not None
         if has_inflight:
-            self.
-
-
-            self.current_inflight_req = adder.add_inflight_req(
-                self.current_inflight_req
+            self.being_chunked_req.init_next_round_input()
+            self.being_chunked_req = adder.add_inflight_req(
+                self.being_chunked_req
             )
 
         if self.lora_paths:
@@ -611,11 +653,11 @@ class Scheduler:
         ]
 
         if adder.new_inflight_req is not None:
-            assert self.
-            self.
+            assert self.being_chunked_req is None
+            self.being_chunked_req = adder.new_inflight_req
 
-        if self.
-            self.
+        if self.being_chunked_req:
+            self.being_chunked_req.is_being_chunked += 1
 
         # Print stats
         if self.tp_rank == 0:
@@ -670,9 +712,11 @@ class Scheduler:
 
         # Mixed-style chunked prefill
         if self.is_mixed_chunk and self.running_batch is not None:
-            self.running_batch.
-
-
+            self.running_batch.filter_batch()
+            if not self.running_batch.is_empty():
+                self.running_batch.prepare_for_decode(self.enable_overlap)
+                new_batch.mix_with_running(self.running_batch)
+                new_batch.decoding_reqs = self.running_batch.reqs
             self.running_batch = None
         else:
             new_batch.decoding_reqs = None
@@ -721,6 +765,8 @@ class Scheduler:
 
     def run_batch(self, batch: ScheduleBatch):
         """Run a batch."""
+        self.forward_ct += 1
+
         if self.is_generation:
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                 model_worker_batch = batch.get_model_worker_batch()
@@ -729,7 +775,7 @@ class Scheduler:
                 )
             else:
                 logits_output = None
-                if self.
+                if self.skip_tokenizer_init:
                     next_token_ids = torch.full(
                         (batch.batch_size(),), self.tokenizer.eos_token_id
                     )
@@ -753,6 +799,7 @@ class Scheduler:
             self.process_batch_result_prefill(batch, result)
 
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
+
         if self.is_generation:
             logits_output, next_token_ids, bid = result
 
@@ -778,9 +825,10 @@ class Scheduler:
             # Check finish conditions
             logprob_pt = 0
             for i, req in enumerate(batch.reqs):
-                if req.
-
-
+                if req.is_retracted:
+                    continue
+
+                if req.is_being_chunked <= 0:
                     # Inflight reqs' prefill is not finished
                     req.completion_tokens_wo_jump_forward += 1
                     req.output_ids.append(next_token_ids[i])
@@ -791,24 +839,28 @@ class Scheduler:
                 elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                     self.tree_cache.cache_unfinished_req(req)
 
-                if req.
-                    req.
-                        req.regex_fsm_state, next_token_ids[i]
-                    )
+                if req.grammar is not None:
+                    req.grammar.accept_token(next_token_ids[i])
 
                 if req.return_logprob:
                     logprob_pt += self.add_logprob_return_values(
                         i, req, logprob_pt, next_token_ids, logits_output
                     )
+            else:
+                req.is_being_chunked -= 1
+
 
         else:  # embedding or reward model
             embeddings, bid = result
             embeddings = embeddings.tolist()
 
             # Check finish conditions
             for i, req in enumerate(batch.reqs):
+                if req.is_retracted:
+                    continue
+
                 req.embedding = embeddings[i]
-                if req.
-                    req.
+                if req.is_being_chunked > 0:
+                    req.is_being_chunked -= 1
                 else:
                     # Inflight reqs' prefill is not finished
                     # dummy output token for embedding models
@@ -828,6 +880,7 @@ class Scheduler:
 
         if self.enable_overlap:
             logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+            next_token_logprobs = logits_output.next_token_logprobs
         else:
             # Move next_token_ids and logprobs to cpu
             if batch.return_logprob:
@@ -841,7 +894,12 @@ class Scheduler:
 
         # Check finish condition
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
-            if
+            if req.is_retracted:
+                continue
+
+            if self.server_args.enable_overlap_schedule and (
+                req.finished()
+            ):
                 self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue
 
@@ -849,10 +907,8 @@ class Scheduler:
             req.output_ids.append(next_token_id)
             req.check_finished()
 
-            if req.
-                req.
-                    req.regex_fsm_state, next_token_id
-                )
+            if req.grammar is not None:
+                req.grammar.accept_token(next_token_id)
 
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
@@ -868,8 +924,8 @@ class Scheduler:
 
         self.token_to_kv_pool.free_group_end()
 
-        self.
-        if self.tp_rank == 0 and self.
+        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
+        if self.tp_rank == 0 and self.forward_ct_decode % self.server_args.decode_log_interval == 0:
             self.print_decode_stats()
 
     def add_logprob_return_values(
@@ -948,22 +1004,24 @@ class Scheduler:
     def stream_output(self, reqs: List[Req]):
         """Stream the output to detokenizer."""
         output_rids = []
-        output_meta_info = []
+        output_meta_info: List[dict] = []
         output_finished_reason: List[BaseFinishReason] = []
         if self.is_generation:
             output_vids = []
             decoded_texts = []
             output_read_ids = []
             output_read_offsets = []
+            output_ids = []
             output_skip_special_tokens = []
             output_spaces_between_special_tokens = []
             output_no_stop_trim = []
         else:  # embedding or reward model
             output_embeddings = []
 
-        is_stream_iter = self.
+        is_stream_iter = self.forward_ct_decode % self.stream_interval == 0
 
         for req in reqs:
+            # TODO(lianmin): revisit this for overlap + retract + stream
             if req.finished() or (
                 req.stream and (is_stream_iter or len(req.output_ids) == 1)
             ):
@@ -975,6 +1033,8 @@ class Scheduler:
                 read_ids, read_offset = req.init_incremental_detokenize()
                 output_read_ids.append(read_ids)
                 output_read_offsets.append(read_offset)
+                if self.skip_tokenizer_init:
+                    output_ids.append(req.output_ids)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
                 )
@@ -1026,6 +1086,7 @@ class Scheduler:
                     decoded_texts,
                     output_read_ids,
                     output_read_offsets,
+                    output_ids,
                     output_skip_special_tokens,
                     output_spaces_between_special_tokens,
                     output_meta_info,
@@ -1050,7 +1111,9 @@ class Scheduler:
         ):
             self.tree_cache.reset()
             self.tree_cache_metrics = {"total": 0, "hit": 0}
-            self.
+            if self.grammar_cache is not None:
+                self.grammar_cache.reset()
+            # TODO(dark): reset the bnf cache
             self.req_to_token_pool.clear()
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()