sglang 0.3.4.post2__py3-none-any.whl → 0.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_latency.py +3 -3
- sglang/bench_server_latency.py +2 -3
- sglang/bench_serving.py +92 -0
- sglang/global_config.py +9 -3
- sglang/lang/chat_template.py +50 -25
- sglang/lang/interpreter.py +9 -1
- sglang/lang/ir.py +11 -2
- sglang/launch_server.py +1 -1
- sglang/srt/configs/model_config.py +51 -13
- sglang/srt/constrained/__init__.py +18 -0
- sglang/srt/constrained/bnf_cache.py +61 -0
- sglang/srt/constrained/grammar.py +190 -0
- sglang/srt/hf_transformers_utils.py +6 -5
- sglang/srt/layers/attention/triton_ops/decode_attention.py +110 -30
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
- sglang/srt/layers/fused_moe/fused_moe.py +4 -3
- sglang/srt/layers/fused_moe/layer.py +28 -0
- sglang/srt/layers/quantization/base_config.py +16 -1
- sglang/srt/layers/vocab_parallel_embedding.py +486 -0
- sglang/srt/managers/data_parallel_controller.py +7 -6
- sglang/srt/managers/detokenizer_manager.py +9 -11
- sglang/srt/managers/image_processor.py +4 -3
- sglang/srt/managers/io_struct.py +70 -78
- sglang/srt/managers/schedule_batch.py +33 -49
- sglang/srt/managers/schedule_policy.py +24 -13
- sglang/srt/managers/scheduler.py +137 -80
- sglang/srt/managers/tokenizer_manager.py +224 -336
- sglang/srt/managers/tp_worker.py +5 -5
- sglang/srt/mem_cache/flush_cache.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +7 -4
- sglang/srt/model_executor/model_runner.py +8 -17
- sglang/srt/models/baichuan.py +4 -4
- sglang/srt/models/chatglm.py +4 -4
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +5 -5
- sglang/srt/models/deepseek.py +4 -4
- sglang/srt/models/deepseek_v2.py +4 -4
- sglang/srt/models/exaone.py +4 -4
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -1
- sglang/srt/models/gpt2.py +287 -0
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +4 -4
- sglang/srt/models/internlm2.py +4 -4
- sglang/srt/models/llama.py +15 -7
- sglang/srt/models/llama_embedding.py +2 -10
- sglang/srt/models/llama_reward.py +5 -0
- sglang/srt/models/minicpm.py +4 -4
- sglang/srt/models/minicpm3.py +4 -4
- sglang/srt/models/mixtral.py +7 -5
- sglang/srt/models/mixtral_quant.py +4 -4
- sglang/srt/models/mllama.py +5 -5
- sglang/srt/models/olmo.py +4 -4
- sglang/srt/models/olmoe.py +4 -4
- sglang/srt/models/qwen.py +4 -4
- sglang/srt/models/qwen2.py +4 -4
- sglang/srt/models/qwen2_moe.py +4 -4
- sglang/srt/models/qwen2_vl.py +4 -8
- sglang/srt/models/stablelm.py +4 -4
- sglang/srt/models/torch_native_llama.py +4 -4
- sglang/srt/models/xverse.py +4 -4
- sglang/srt/models/xverse_moe.py +4 -4
- sglang/srt/openai_api/adapter.py +52 -66
- sglang/srt/sampling/sampling_batch_info.py +7 -13
- sglang/srt/server.py +31 -35
- sglang/srt/server_args.py +34 -5
- sglang/srt/utils.py +40 -56
- sglang/test/runners.py +2 -1
- sglang/test/test_utils.py +73 -25
- sglang/utils.py +62 -1
- sglang/version.py +1 -1
- sglang-0.3.5.dist-info/METADATA +344 -0
- {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/RECORD +77 -73
- {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/WHEEL +1 -1
- sglang-0.3.4.post2.dist-info/METADATA +0 -899
- {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/LICENSE +0 -0
- {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -15,22 +15,21 @@ limitations under the License.
 
 """A scheduler that manages a tensor parallel GPU worker."""
 
-import json
 import logging
 import os
+import threading
 import time
 import warnings
 from collections import deque
 from types import SimpleNamespace
-from typing import List, Optional
+from typing import List, Optional
 
 import torch
 import zmq
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained.
-from sglang.srt.constrained.jump_forward import JumpForwardCache
+from sglang.srt.constrained.grammar import GrammarCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
@@ -43,7 +42,6 @@ from sglang.srt.managers.io_struct import (
     ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    TokenizedRewardReqInput,
     UpdateWeightReqInput,
     UpdateWeightReqOutput,
 )
@@ -68,8 +66,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
-
-    is_multimodal_model,
+    get_zmq_socket,
     kill_parent_process,
     set_random_seed,
     suppress_other_loggers,
@@ -78,6 +75,7 @@ from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
+
 # Crash on warning if we are running CI tests
 crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
 
@@ -105,16 +103,26 @@ class Scheduler:
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
         self.enable_overlap = server_args.enable_overlap_schedule
+        self.skip_tokenizer_init = server_args.skip_tokenizer_init
 
         # Init inter-process communication
         context = zmq.Context(2)
 
         if self.tp_rank == 0:
-            self.recv_from_tokenizer =
-
+            self.recv_from_tokenizer = get_zmq_socket(
+                context, zmq.PULL, port_args.scheduler_input_ipc_name
+            )
 
-
-
+            if server_args.skip_tokenizer_init:
+                # Directly send to the tokenizer/api
+                self.send_to_detokenizer = get_zmq_socket(
+                    context, zmq.PUSH, port_args.tokenizer_ipc_name
+                )
+            else:
+                # Send to the detokenizer
+                self.send_to_detokenizer = get_zmq_socket(
+                    context, zmq.PUSH, port_args.detokenizer_ipc_name
+                )
         else:
             self.recv_from_tokenizer = None
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
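Note: get_zmq_socket is imported from sglang.srt.utils in the hunk above, but its body is not part of this diff. The snippet below is only a minimal sketch of what such a helper might look like, assuming the receiving PULL side binds and the sending PUSH side connects to the ipc:// endpoint; the function name and behavior are illustrative, not the actual sglang implementation.

import zmq

def get_zmq_socket_sketch(context: zmq.Context, socket_type: int, ipc_name: str) -> zmq.Socket:
    # Hypothetical helper: open a socket over an ipc:// endpoint.
    # Assumption (not shown in this diff): PULL binds, PUSH connects.
    socket = context.socket(socket_type)
    endpoint = f"ipc://{ipc_name}"
    if socket_type == zmq.PULL:
        socket.set_hwm(0)  # do not drop queued requests
        socket.bind(endpoint)
    else:
        socket.connect(endpoint)
    return socket

With a helper of this shape, the hunk reads as: TP rank 0 pulls tokenized requests from the scheduler input socket and pushes results either straight back to the tokenizer/api process (when skip_tokenizer_init is set) or on to the detokenizer.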
@@ -122,15 +130,17 @@ class Scheduler:
         # Init tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
-            server_args.trust_remote_code,
+            trust_remote_code=server_args.trust_remote_code,
             context_length=server_args.context_length,
-            model_override_args=
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
         )
+        self.is_generation = self.model_config.is_generation
 
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
-            if
+            if self.model_config.is_multimodal:
                 self.processor = get_processor(
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,
@@ -143,9 +153,6 @@ class Scheduler:
                     tokenizer_mode=server_args.tokenizer_mode,
                     trust_remote_code=server_args.trust_remote_code,
                 )
-        self.is_generation = is_generation_model(
-            self.model_config.hf_config.architectures, self.server_args.is_embedding
-        )
 
         # Launch a tensor parallel worker
         if self.enable_overlap:
@@ -212,44 +219,62 @@ class Scheduler:
         self.waiting_queue: List[Req] = []
         self.running_batch: Optional[ScheduleBatch] = None
         self.cur_batch: Optional[ScheduleBatch] = None
-        self.
-        self.
+        self.forward_ct = 0
+        self.forward_ct_decode = 0
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()
+        self.stream_interval = server_args.stream_interval
 
         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
-        self.
+        self.being_chunked_req = None
         self.is_mixed_chunk = (
             self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
         )
 
         # Init the FSM cache for constrained generation
+        self.grammar_cache = None
+
         if not server_args.skip_tokenizer_init:
-            self.
+            self.grammar_cache = GrammarCache(
                 server_args.tokenizer_path,
                 {
                     "tokenizer_mode": server_args.tokenizer_mode,
                    "trust_remote_code": server_args.trust_remote_code,
                 },
                 skip_tokenizer_init=server_args.skip_tokenizer_init,
-
+                whitespace_patterns=server_args.constrained_json_whitespace_pattern,
+                backend=server_args.grammar_backend,
+                allow_jump=not server_args.disable_regex_jump_forward,
             )
-        self.jump_forward_cache = JumpForwardCache()
 
         # Init new token estimation
         assert (
             server_args.schedule_conservativeness >= 0
         ), "Invalid schedule_conservativeness"
-
-
+
+        self.init_new_token_ratio = min(
+            global_config.default_init_new_token_ratio
             * server_args.schedule_conservativeness,
             1.0,
         )
-        self.
-
+        self.min_new_token_ratio = min(
+            self.init_new_token_ratio
+            * global_config.default_min_new_token_ratio_factor,
+            1.0,
+        )
+        self.new_token_ratio_decay = (
+            self.init_new_token_ratio - self.min_new_token_ratio
+        ) / global_config.default_new_token_ratio_decay_steps
+        self.new_token_ratio = self.init_new_token_ratio
+
         self.batch_is_full = False
 
+        # Init watchdog thread
+        self.watchdog_timeout = server_args.watchdog_timeout
+        t = threading.Thread(target=self.watchdog_thread, daemon=True)
+        t.start()
+
         # Init profiler
         if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
             self.profiler = None
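To make the constructor arithmetic above concrete, here is a small worked example. The default_* constants below are placeholders, not the real values in sglang.global_config, and the decay itself is applied elsewhere in the scheduler (outside this hunk).

# Placeholder defaults; the real numbers live in sglang.global_config.
default_init_new_token_ratio = 0.7
default_min_new_token_ratio_factor = 0.14
default_new_token_ratio_decay_steps = 600
schedule_conservativeness = 1.0

init_ratio = min(default_init_new_token_ratio * schedule_conservativeness, 1.0)
min_ratio = min(init_ratio * default_min_new_token_ratio_factor, 1.0)
decay = (init_ratio - min_ratio) / default_new_token_ratio_decay_steps
# Roughly 0.7, 0.098 and 0.001 with these placeholder values.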
@@ -267,6 +292,23 @@ class Scheduler:
                 with_stack=True,
             )
 
+    def watchdog_thread(self):
+        self.watchdog_last_forward_ct = 0
+        self.watchdog_last_time = time.time()
+
+        while True:
+            if self.cur_batch is not None:
+                if self.watchdog_last_forward_ct == self.forward_ct:
+                    if time.time() > self.watchdog_last_time + self.watchdog_timeout:
+                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
+                        break
+                else:
+                    self.watchdog_last_forward_ct = self.forward_ct
+                    self.watchdog_last_time = time.time()
+            time.sleep(self.watchdog_timeout / 2)
+
+        kill_parent_process()
+
     @torch.inference_mode()
     def event_loop_normal(self):
         """A normal blocking scheduler loop."""
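The watchdog added above is a plain stall detector: a daemon thread wakes every timeout/2 seconds and, if a batch is in flight but the forward counter has not moved for a full timeout window, logs an error and kills the process tree. A self-contained sketch of the same pattern (class and attribute names below are illustrative, not sglang API):

import os
import signal
import threading
import time

class StallWatchdog:
    """Kill the process if forward_ct stops advancing while work is pending."""

    def __init__(self, timeout_s: float):
        self.timeout_s = timeout_s
        self.forward_ct = 0   # the worker bumps this on every forward pass
        self.busy = False     # True while a batch is in flight
        threading.Thread(target=self._run, daemon=True).start()

    def _run(self):
        last_ct, last_time = 0, time.time()
        while True:
            if self.busy:
                if self.forward_ct == last_ct:
                    if time.time() > last_time + self.timeout_s:
                        break  # no progress for a full timeout window
                else:
                    last_ct, last_time = self.forward_ct, time.time()
            time.sleep(self.timeout_s / 2)
        os.kill(os.getpid(), signal.SIGKILL)  # stand-in for kill_parent_process()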
@@ -277,6 +319,7 @@ class Scheduler:
             self.process_input_requests(recv_reqs)
 
             batch = self.get_next_batch_to_run()
+            self.cur_batch = batch
 
             if batch:
                 result = self.run_batch(batch)
@@ -294,7 +337,7 @@ class Scheduler:
                 self.process_batch_result(batch, result)
             else:
                 self.check_memory()
-                self.new_token_ratio =
+                self.new_token_ratio = self.init_new_token_ratio
 
             self.last_batch = batch
 
@@ -321,7 +364,7 @@ class Scheduler:
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
                 self.check_memory()
-                self.new_token_ratio =
+                self.new_token_ratio = self.init_new_token_ratio
 
             self.last_batch = batch
 
@@ -346,9 +389,7 @@ class Scheduler:
         for recv_req in recv_reqs:
             if isinstance(recv_req, TokenizedGenerateReqInput):
                 self.handle_generate_request(recv_req)
-            elif isinstance(
-                recv_req, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput)
-            ):
+            elif isinstance(recv_req, TokenizedEmbeddingReqInput):
                 self.handle_embedding_request(recv_req)
             elif isinstance(recv_req, FlushCacheReq):
                 self.flush_cache()
@@ -402,22 +443,20 @@ class Scheduler:
             # By default, only return the logprobs for output tokens
             req.logprob_start_len = len(recv_req.input_ids) - 1
 
-        # Init regex FSM
+        # Init regex FSM or BNF
         if (
             req.sampling_params.json_schema is not None
             or req.sampling_params.regex is not None
         ):
+            assert self.grammar_cache is not None
             if req.sampling_params.json_schema is not None:
-                req.
-                    ("json", req.sampling_params.json_schema)
+                req.grammar = self.grammar_cache.query(
+                    ("json", req.sampling_params.json_schema),
+                    self.model_config.vocab_size,
                 )
             elif req.sampling_params.regex is not None:
-                req.
-                    ("regex", req.sampling_params.regex)
-                )
-                if not self.disable_regex_jump_forward:
-                    req.jump_forward_map = self.jump_forward_cache.query(
-                        computed_regex_string
+                req.grammar = self.grammar_cache.query(
+                    ("regex", req.sampling_params.regex), self.model_config.vocab_size
                 )
 
         # Truncate prompts that are too long
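handle_generate_request now resolves a grammar through self.grammar_cache.query(key, vocab_size), keyed by ("json", schema) or ("regex", pattern). The sketch below shows only the memoization idea behind such a cache; the real GrammarCache in sglang.srt.constrained.grammar additionally handles tokenizer setup, backends, and jump-forward, none of which is modeled here.

from typing import Dict, Tuple

class TinyGrammarCache:
    """Illustrative only: memoize compiled grammars by ("json" | "regex", spec)."""

    def __init__(self, compile_fn):
        self._compile = compile_fn  # e.g. builds an FSM or BNF matcher
        self._cache: Dict[Tuple[str, str], object] = {}

    def query(self, key: Tuple[str, str], vocab_size: int):
        mode, spec = key
        if key not in self._cache:
            self._cache[key] = self._compile(mode, spec, vocab_size)
        return self._cache[key]

    def reset(self):
        self._cache.clear()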
@@ -441,7 +480,7 @@ class Scheduler:
 
     def handle_embedding_request(
         self,
-        recv_req:
+        recv_req: TokenizedEmbeddingReqInput,
     ):
         req = Req(
             recv_req.rid,
@@ -506,13 +545,13 @@ class Scheduler:
             and not self.last_batch.forward_mode.is_decode()
             and not self.last_batch.is_empty()
         ):
-            if self.
+            if self.being_chunked_req:
                 self.last_batch.filter_batch(
-
+                    being_chunked_req=self.being_chunked_req
                 )
-                self.tree_cache.cache_unfinished_req(self.
+                self.tree_cache.cache_unfinished_req(self.being_chunked_req)
                 # Inflight request keeps its rid but will get a new req_pool_idx.
-                self.req_to_token_pool.free(self.
+                self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
                 self.batch_is_full = False
             if not self.last_batch.is_empty():
                 if self.running_batch is None:
@@ -543,7 +582,7 @@ class Scheduler:
         # Handle the cases where prefill is not allowed
         if (
             self.batch_is_full or len(self.waiting_queue) == 0
-        ) and self.
+        ) and self.being_chunked_req is None:
             return None
 
         running_bs = len(self.running_batch.reqs) if self.running_batch else 0
@@ -566,13 +605,11 @@ class Scheduler:
             num_mixed_running,
         )
 
-        has_inflight = self.
+        has_inflight = self.being_chunked_req is not None
         if has_inflight:
-            self.
-
-
-            self.current_inflight_req = adder.add_inflight_req(
-                self.current_inflight_req
+            self.being_chunked_req.init_next_round_input()
+            self.being_chunked_req = adder.add_inflight_req(
+                self.being_chunked_req
             )
 
         if self.lora_paths:
@@ -616,11 +653,11 @@ class Scheduler:
         ]
 
         if adder.new_inflight_req is not None:
-            assert self.
-            self.
+            assert self.being_chunked_req is None
+            self.being_chunked_req = adder.new_inflight_req
 
-        if self.
-            self.
+        if self.being_chunked_req:
+            self.being_chunked_req.is_being_chunked += 1
 
         # Print stats
         if self.tp_rank == 0:
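Several hunks above rename current_inflight_req to being_chunked_req and drive it with an is_being_chunked counter: the counter is incremented each time the request is scheduled for another prefill chunk and decremented when that chunk's result comes back, so the request only starts emitting output tokens once no chunk is outstanding. A toy illustration of that bookkeeping (simplified, not sglang's actual control flow):

class ToyReq:
    def __init__(self, rid):
        self.rid = rid
        self.is_being_chunked = 0  # > 0 while a prefill chunk is still pending
        self.output_ids = []

def schedule_chunk(req: ToyReq):
    # The request was put into a prefill batch as the in-flight ("being
    # chunked") request: one more chunk result is now outstanding.
    req.is_being_chunked += 1

def on_prefill_result(req: ToyReq, next_token_id: int):
    if req.is_being_chunked <= 0:
        # Prefill is fully done: the request starts decoding normally.
        req.output_ids.append(next_token_id)
    else:
        # Only a partial prefill finished; wait for the remaining chunks.
        req.is_being_chunked -= 1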
@@ -675,9 +712,11 @@ class Scheduler:
 
         # Mixed-style chunked prefill
         if self.is_mixed_chunk and self.running_batch is not None:
-            self.running_batch.
-
-
+            self.running_batch.filter_batch()
+            if not self.running_batch.is_empty():
+                self.running_batch.prepare_for_decode(self.enable_overlap)
+                new_batch.mix_with_running(self.running_batch)
+                new_batch.decoding_reqs = self.running_batch.reqs
             self.running_batch = None
         else:
             new_batch.decoding_reqs = None
@@ -726,6 +765,8 @@ class Scheduler:
 
     def run_batch(self, batch: ScheduleBatch):
         """Run a batch."""
+        self.forward_ct += 1
+
         if self.is_generation:
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                 model_worker_batch = batch.get_model_worker_batch()
@@ -734,7 +775,7 @@ class Scheduler:
                 )
             else:
                 logits_output = None
-                if self.
+                if self.skip_tokenizer_init:
                     next_token_ids = torch.full(
                         (batch.batch_size(),), self.tokenizer.eos_token_id
                     )
@@ -758,6 +799,7 @@ class Scheduler:
             self.process_batch_result_prefill(batch, result)
 
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
+
         if self.is_generation:
             logits_output, next_token_ids, bid = result
 
@@ -783,9 +825,10 @@ class Scheduler:
             # Check finish conditions
             logprob_pt = 0
             for i, req in enumerate(batch.reqs):
-                if req.
-
-
+                if req.is_retracted:
+                    continue
+
+                if req.is_being_chunked <= 0:
                     # Inflight reqs' prefill is not finished
                     req.completion_tokens_wo_jump_forward += 1
                     req.output_ids.append(next_token_ids[i])
@@ -796,24 +839,28 @@ class Scheduler:
                     elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                         self.tree_cache.cache_unfinished_req(req)
 
-                    if req.
-                        req.
-                            req.regex_fsm_state, next_token_ids[i]
-                        )
+                    if req.grammar is not None:
+                        req.grammar.accept_token(next_token_ids[i])
 
                     if req.return_logprob:
                         logprob_pt += self.add_logprob_return_values(
                             i, req, logprob_pt, next_token_ids, logits_output
                         )
+                else:
+                    req.is_being_chunked -= 1
+
         else:  # embedding or reward model
             embeddings, bid = result
             embeddings = embeddings.tolist()
 
             # Check finish conditions
             for i, req in enumerate(batch.reqs):
+                if req.is_retracted:
+                    continue
+
                 req.embedding = embeddings[i]
-                if req.
-                    req.
+                if req.is_being_chunked > 0:
+                    req.is_being_chunked -= 1
                 else:
                     # Inflight reqs' prefill is not finished
                     # dummy output token for embedding models
@@ -847,7 +894,12 @@ class Scheduler:
 
         # Check finish condition
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
-            if
+            if req.is_retracted:
+                continue
+
+            if self.server_args.enable_overlap_schedule and (
+                req.finished()
+            ):
                 self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue
 
@@ -855,10 +907,8 @@ class Scheduler:
             req.output_ids.append(next_token_id)
             req.check_finished()
 
-            if req.
-                req.
-                    req.regex_fsm_state, next_token_id
-                )
+            if req.grammar is not None:
+                req.grammar.accept_token(next_token_id)
 
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
@@ -874,8 +924,8 @@ class Scheduler:
 
         self.token_to_kv_pool.free_group_end()
 
-        self.
-        if self.tp_rank == 0 and self.
+        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
+        if self.tp_rank == 0 and self.forward_ct_decode % self.server_args.decode_log_interval == 0:
             self.print_decode_stats()
 
     def add_logprob_return_values(
@@ -954,22 +1004,24 @@ class Scheduler:
     def stream_output(self, reqs: List[Req]):
         """Stream the output to detokenizer."""
         output_rids = []
-        output_meta_info = []
+        output_meta_info: List[dict] = []
         output_finished_reason: List[BaseFinishReason] = []
         if self.is_generation:
             output_vids = []
             decoded_texts = []
             output_read_ids = []
             output_read_offsets = []
+            output_ids = []
             output_skip_special_tokens = []
             output_spaces_between_special_tokens = []
             output_no_stop_trim = []
         else:  # embedding or reward model
             output_embeddings = []
 
-        is_stream_iter = self.
+        is_stream_iter = self.forward_ct_decode % self.stream_interval == 0
 
         for req in reqs:
+            # TODO(lianmin): revisit this for overlap + retract + stream
            if req.finished() or (
                req.stream and (is_stream_iter or len(req.output_ids) == 1)
            ):
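stream_output gates partial results on forward_ct_decode % self.stream_interval, so a streaming request is flushed on its first output token, on every stream_interval-th decode step, and when it finishes. A condensed sketch of that condition (the stream_interval value is illustrative; the real one comes from server_args.stream_interval):

stream_interval = 8

def should_stream(forward_ct_decode: int, finished: bool, num_output_ids: int, stream: bool) -> bool:
    is_stream_iter = forward_ct_decode % stream_interval == 0
    return finished or (stream and (is_stream_iter or num_output_ids == 1))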
@@ -981,6 +1033,8 @@ class Scheduler:
                 read_ids, read_offset = req.init_incremental_detokenize()
                 output_read_ids.append(read_ids)
                 output_read_offsets.append(read_offset)
+                if self.skip_tokenizer_init:
+                    output_ids.append(req.output_ids)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
                 )
@@ -1032,6 +1086,7 @@ class Scheduler:
                     decoded_texts,
                     output_read_ids,
                     output_read_offsets,
+                    output_ids,
                     output_skip_special_tokens,
                     output_spaces_between_special_tokens,
                     output_meta_info,
@@ -1056,7 +1111,9 @@ class Scheduler:
         ):
             self.tree_cache.reset()
             self.tree_cache_metrics = {"total": 0, "hit": 0}
-            self.
+            if self.grammar_cache is not None:
+                self.grammar_cache.reset()
+            # TODO(dark): reset the bnf cache
             self.req_to_token_pool.clear()
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()