sglang 0.3.3.post1__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_latency.py +28 -10
- sglang/bench_server_latency.py +21 -10
- sglang/bench_serving.py +101 -7
- sglang/global_config.py +0 -1
- sglang/srt/layers/attention/__init__.py +27 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
- sglang/srt/layers/attention/flashinfer_backend.py +352 -83
- sglang/srt/layers/attention/triton_backend.py +6 -4
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
- sglang/srt/layers/sampler.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +31 -10
- sglang/srt/managers/io_struct.py +4 -0
- sglang/srt/managers/schedule_batch.py +120 -43
- sglang/srt/managers/schedule_policy.py +2 -1
- sglang/srt/managers/scheduler.py +202 -140
- sglang/srt/managers/tokenizer_manager.py +5 -1
- sglang/srt/managers/tp_worker.py +111 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -4
- sglang/srt/mem_cache/memory_pool.py +77 -4
- sglang/srt/mem_cache/radix_cache.py +15 -7
- sglang/srt/model_executor/cuda_graph_runner.py +4 -4
- sglang/srt/model_executor/forward_batch_info.py +16 -21
- sglang/srt/model_executor/model_runner.py +60 -1
- sglang/srt/models/baichuan.py +2 -3
- sglang/srt/models/chatglm.py +5 -6
- sglang/srt/models/commandr.py +1 -2
- sglang/srt/models/dbrx.py +1 -2
- sglang/srt/models/deepseek.py +4 -5
- sglang/srt/models/deepseek_v2.py +5 -6
- sglang/srt/models/exaone.py +1 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +5 -5
- sglang/srt/models/gpt_bigcode.py +5 -5
- sglang/srt/models/grok.py +1 -2
- sglang/srt/models/internlm2.py +1 -2
- sglang/srt/models/llama.py +1 -2
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +4 -8
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -2
- sglang/srt/models/minicpm3.py +5 -6
- sglang/srt/models/mixtral.py +1 -2
- sglang/srt/models/mixtral_quant.py +1 -2
- sglang/srt/models/olmo.py +352 -0
- sglang/srt/models/olmoe.py +1 -2
- sglang/srt/models/qwen.py +1 -2
- sglang/srt/models/qwen2.py +1 -2
- sglang/srt/models/qwen2_moe.py +4 -5
- sglang/srt/models/stablelm.py +1 -2
- sglang/srt/models/torch_native_llama.py +1 -2
- sglang/srt/models/xverse.py +1 -2
- sglang/srt/models/xverse_moe.py +4 -5
- sglang/srt/models/yivl.py +1 -2
- sglang/srt/openai_api/adapter.py +92 -49
- sglang/srt/openai_api/protocol.py +10 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
- sglang/srt/sampling/sampling_batch_info.py +92 -58
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server.py +116 -17
- sglang/srt/server_args.py +121 -45
- sglang/srt/utils.py +11 -3
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/few_shot_gsm8k_engine.py +144 -0
- sglang/test/srt/sampling/penaltylib/utils.py +16 -12
- sglang/version.py +1 -1
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/METADATA +72 -29
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/RECORD +73 -70
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
- sglang/srt/layers/attention/flashinfer_utils.py +0 -237
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -17,10 +17,11 @@ limitations under the License.
 
 import json
 import logging
-import multiprocessing
 import os
 import time
 import warnings
+from collections import deque
+from types import SimpleNamespace
 from typing import List, Optional, Union
 
 import torch
@@ -77,6 +78,9 @@ logger = logging.getLogger(__name__)
 # Crash on warning if we are running CI tests
 crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
 
+# Test retract decode
+test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"
+
 
 class Scheduler:
     """A scheduler that manages a tensor parallel GPU worker."""
@@ -107,7 +111,8 @@ class Scheduler:
             self.send_to_detokenizer = context.socket(zmq.PUSH)
             self.send_to_detokenizer.connect(f"ipc://{port_args.detokenizer_ipc_name}")
         else:
-            self.recv_from_tokenizer =
+            self.recv_from_tokenizer = None
+            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
 
         # Init tokenizer
         self.model_config = ModelConfig(
@@ -145,6 +150,7 @@ class Scheduler:
             nccl_port=port_args.nccl_port,
         )
         self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
+        self.device = self.tp_worker.device
 
         # Get token and memory info from the model worker
         (
@@ -190,8 +196,8 @@ class Scheduler:
 
         # Init running status
         self.waiting_queue: List[Req] = []
-        self.running_batch: ScheduleBatch = None
-        self.
+        self.running_batch: Optional[ScheduleBatch] = None
+        self.cur_batch: Optional[ScheduleBatch] = None
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
         self.num_generated_tokens = 0
@@ -230,6 +236,7 @@ class Scheduler:
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
        self.batch_is_full = False
 
+        # Init profiler
        if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
            self.profiler = None
        else:
@@ -246,15 +253,75 @@ class Scheduler:
                 with_stack=True,
             )
 
+        # Init states for overlap schedule
+        if self.server_args.enable_overlap_schedule:
+            self.forward_batch_generation = (
+                self.tp_worker.forward_batch_generation_non_blocking
+            )
+            self.resolve_next_token_ids = (
+                lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
+            )
+            self.cache_finished_req = self.tree_cache.cache_finished_req
+        else:
+            self.forward_batch_generation = self.tp_worker.forward_batch_generation
+            self.resolve_next_token_ids = lambda bid, x: x.tolist()
+            self.cache_finished_req = self.tree_cache.cache_finished_req
+
     @torch.inference_mode()
-    def
+    def event_loop_normal(self):
+        self.last_batch = None
+
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
 
-            self.
+            batch = self.get_next_batch_to_run()
+
+            if batch:
+                result = self.run_batch(batch)
+                self.process_batch_result(batch, result)
+
+                # Decode multiple steps to reduce the overhead
+                if batch.forward_mode.is_decode():
+                    for _ in range(self.server_args.num_continuous_decode_steps - 1):
+                        if not self.running_batch:
+                            break
+                        self.update_running_batch()
+                        if not self.running_batch:
+                            break
+                        result = self.run_batch(batch)
+                        self.process_batch_result(batch, result)
+            else:
+                self.check_memory()
+                self.new_token_ratio = global_config.init_new_token_ratio
+
+            self.last_batch = batch
+
+    @torch.inference_mode()
+    def event_loop_overlap(self):
+        result_queue = deque()
+
+        self.last_batch = None
+        self.running_batch = None
 
-
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+
+            batch = self.get_next_batch_to_run()
+            self.cur_batch = batch
+            if batch:
+                result = self.run_batch(batch)
+                result_queue.append((batch.copy(), result))
+
+            if self.last_batch:
+                tmp_batch, tmp_result = result_queue.popleft()
+                self.process_batch_result(tmp_batch, tmp_result)
+            elif batch is None:
+                self.check_memory()
+                self.new_token_ratio = global_config.init_new_token_ratio
+
+            self.last_batch = batch
 
     def recv_requests(self):
         if self.tp_rank == 0:
@@ -287,7 +354,9 @@ class Scheduler:
             self.abort_request(recv_req)
         elif isinstance(recv_req, UpdateWeightReqInput):
             success, message = self.update_weights(recv_req)
-            self.
+            self.send_to_detokenizer.send_pyobj(
+                UpdateWeightReqOutput(success, message)
+            )
         elif isinstance(recv_req, ProfileReq):
             if recv_req == ProfileReq.START_PROFILE:
                 self.start_profile()
@@ -385,12 +454,6 @@ class Scheduler:
 
         self.waiting_queue.append(req)
 
-    def send_results(self):
-        if self.tp_rank == 0:
-            for obj in self.out_pyobjs:
-                self.send_to_detokenizer.send_pyobj(obj)
-            self.out_pyobjs = []
-
     def print_decode_stats(self):
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
@@ -398,9 +461,10 @@ class Scheduler:
         throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()
+        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
         logger.info(
             f"Decode batch. "
-            f"#running-req: {
+            f"#running-req: {num_running_reqs}, "
             f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"gen throughput (token/s): {throughput:.2f}, "
@@ -428,44 +492,45 @@ class Scheduler:
         )
         exit(1) if crash_on_warning else None
 
-    def
+    def get_next_batch_to_run(self):
+        # Merge the prefill batch into the running batch
+        if (
+            self.last_batch
+            and not self.last_batch.forward_mode.is_decode()
+            and not self.last_batch.is_empty()
+        ):
+            if self.current_inflight_req:
+                self.last_batch.filter_batch(
+                    current_inflight_req=self.current_inflight_req
+                )
+                self.tree_cache.cache_unfinished_req(self.current_inflight_req)
+                # Inflight request keeps its rid but will get a new req_pool_idx.
+                self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx)
+                self.batch_is_full = False
+            if not self.last_batch.is_empty():
+                if self.running_batch is None:
+                    self.running_batch = self.last_batch
+                else:
+                    self.running_batch.merge_batch(self.last_batch)
+
+        # Prefill first
         new_batch = self.get_new_batch_prefill()
         if new_batch is not None:
-
-            # replace run_batch with the uncommented line to use pytorch profiler
-            # result = pytorch_profile(
-            # "profile_prefill_step", self.run_batch, new_batch, data_size=len(new_batch.reqs)
-            # )
-            result = self.run_batch(new_batch)
-            self.process_batch_result(new_batch, result)
-        else:
-            if self.running_batch is not None:
-                # Run a few decode batches continuously for reducing overhead
-                for _ in range(global_config.num_continue_decode_steps):
-                    batch = self.get_new_batch_decode()
-
-                    if batch:
-                        # replace run_batch with the uncommented line to use pytorch profiler
-                        # result = pytorch_profile(
-                        # "profile_decode_step",
-                        # self.run_batch,
-                        # batch,
-                        # data_size=len(batch.reqs),
-                        # )
-                        result = self.run_batch(batch)
-                        self.process_batch_result(batch, result)
+            return new_batch
 
-
-
+        # Check memory
+        if self.running_batch is None:
+            return
 
-
-
-
-
-
-
-
-
+        # Run decode
+        before_bs = self.running_batch.batch_size()
+        self.update_running_batch()
+        if not self.running_batch:
+            self.batch_is_full = False
+            return None
+        if before_bs != self.running_batch.batch_size():
+            self.batch_is_full = False
+        return self.running_batch
 
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Handle the cases where prefill is not allowed
@@ -474,9 +539,7 @@ class Scheduler:
         ) and self.current_inflight_req is None:
             return None
 
-        running_bs = (
-            len(self.running_batch.reqs) if self.running_batch is not None else 0
-        )
+        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
         if running_bs >= self.max_running_requests:
             self.batch_is_full = True
             return None
@@ -497,7 +560,7 @@ class Scheduler:
         )
 
         has_inflight = self.current_inflight_req is not None
-        if
+        if has_inflight:
             self.current_inflight_req.init_next_round_input(
                 None if prefix_computed else self.tree_cache
             )
@@ -505,7 +568,7 @@ class Scheduler:
                 self.current_inflight_req
             )
 
-        if self.lora_paths
+        if self.lora_paths:
             lora_set = (
                 set([req.lora_path for req in self.running_batch.reqs])
                 if self.running_batch is not None
@@ -514,7 +577,7 @@ class Scheduler:
 
         for req in self.waiting_queue:
             if (
-                self.lora_paths
+                self.lora_paths
                 and len(
                     lora_set
                     | set([req.lora_path for req in adder.can_run_list])
@@ -536,16 +599,20 @@ class Scheduler:
                 self.batch_is_full = True
                 break
 
+        # Update waiting queue
         can_run_list = adder.can_run_list
+        if len(can_run_list) == 0:
+            return None
+        self.waiting_queue = [
+            x for x in self.waiting_queue if x not in set(can_run_list)
+        ]
 
         if adder.new_inflight_req is not None:
            assert self.current_inflight_req is None
            self.current_inflight_req = adder.new_inflight_req
 
-        if
-
-
-        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
+        if self.current_inflight_req:
+            self.current_inflight_req.is_inflight_req += 1
 
         # Print stats
         if self.tp_rank == 0:
@@ -598,21 +665,27 @@ class Scheduler:
         new_batch.prepare_for_extend(self.model_config.vocab_size)
 
         # Mixed-style chunked prefill
-        decoding_reqs = []
         if self.is_mixed_chunk and self.running_batch is not None:
             self.running_batch.prepare_for_decode()
             new_batch.mix_with_running(self.running_batch)
-            decoding_reqs = self.running_batch.reqs
+            new_batch.decoding_reqs = self.running_batch.reqs
             self.running_batch = None
-
+        else:
+            new_batch.decoding_reqs = None
 
         return new_batch
 
-    def
+    def update_running_batch(self):
+        global test_retract
         batch = self.running_batch
 
+        batch.filter_batch()
+        if batch.is_empty():
+            self.running_batch = None
+            return
+
         # Check if decode out of memory
-        if not batch.check_decode_mem():
+        if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
             old_ratio = self.new_token_ratio
 
             retracted_reqs, new_token_ratio = batch.retract_decode()
@@ -635,17 +708,17 @@ class Scheduler:
             jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
             self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():
-
+                self.running_batch = None
+                return
 
         # Update batch tensors
         batch.prepare_for_decode()
-        return batch
 
     def run_batch(self, batch: ScheduleBatch):
         if self.is_generation:
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                 model_worker_batch = batch.get_model_worker_batch()
-                logits_output, next_token_ids = self.
+                logits_output, next_token_ids = self.forward_batch_generation(
                     model_worker_batch
                 )
             else:
@@ -656,34 +729,32 @@ class Scheduler:
                 )
             else:
                 next_token_ids = torch.full((batch.batch_size(),), 0)
-
+            batch.output_ids = next_token_ids
+            ret = logits_output, next_token_ids, model_worker_batch.bid
         else: # embedding or reward model
             assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-
+            ret = embeddings, model_worker_batch.bid
+        return ret
 
     def process_batch_result(self, batch: ScheduleBatch, result):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result)
+            if batch.is_empty():
+                self.running_batch = None
         else:
             self.process_batch_result_prefill(batch, result)
 
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
-            logits_output, next_token_ids = result
-            batch.
-                next_token_ids
-            )
-
-            if logits_output:
+            logits_output, next_token_ids, bid = result
+            if batch.return_logprob:
                 # Move logprobs to cpu
                 if logits_output.next_token_logprobs is not None:
                     logits_output.next_token_logprobs = (
                         logits_output.next_token_logprobs[
-                            torch.arange(
-                                len(next_token_ids), device=next_token_ids.device
-                            ),
+                            torch.arange(len(next_token_ids), device=self.device),
                            next_token_ids,
                        ].tolist()
                    )
@@ -694,84 +765,76 @@ class Scheduler:
                         logits_output.normalized_prompt_logprobs.tolist()
                     )
 
-            next_token_ids =
+            next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
 
             # Check finish conditions
             logprob_pt = 0
             for i, req in enumerate(batch.reqs):
-                if req
+                if req.is_inflight_req > 0:
+                    req.is_inflight_req -= 1
+                else:
                     # Inflight reqs' prefill is not finished
                     req.completion_tokens_wo_jump_forward += 1
                     req.output_ids.append(next_token_ids[i])
                     req.check_finished()
 
-
-
-
-
-
-                if req.finished():
-                    self.tree_cache.cache_finished_req(req)
-                elif req not in batch.decoding_reqs:
-                    # To reduce overhead, only cache prefill reqs
-                    self.tree_cache.cache_unfinished_req(req)
+                    if req.finished():
+                        self.cache_finished_req(req)
+                    elif not batch.decoding_reqs or req not in batch.decoding_reqs:
+                        self.tree_cache.cache_unfinished_req(req)
 
-
-
-
+                    if req.regex_fsm is not None:
+                        req.regex_fsm_state = req.regex_fsm.get_next_state(
+                            req.regex_fsm_state, next_token_ids[i]
+                        )
 
-
-
-
-
+                    if req.return_logprob:
+                        logprob_pt += self.add_logprob_return_values(
+                            i, req, logprob_pt, next_token_ids, logits_output
+                        )
         else: # embedding or reward model
-
-            embeddings =
+            embeddings, bid = result
+            embeddings = embeddings.tolist()
 
             # Check finish conditions
             for i, req in enumerate(batch.reqs):
                 req.embedding = embeddings[i]
-                if req
+                if req.is_inflight_req > 0:
+                    req.is_inflight_req -= 1
+                else:
                     # Inflight reqs' prefill is not finished
                     # dummy output token for embedding models
                     req.output_ids.append(0)
                     req.check_finished()
 
                     if req.finished():
-                        self.
+                        self.cache_finished_req(req)
                     else:
                         self.tree_cache.cache_unfinished_req(req)
 
-
-                    # Inflight request would get a new req idx
-                    self.req_to_token_pool.free(req.req_pool_idx)
-
-        self.handle_finished_requests(batch)
-
-        if not batch.is_empty():
-            if self.running_batch is None:
-                self.running_batch = batch
-            else:
-                self.running_batch.merge_batch(batch)
+        self.stream_output(batch.reqs)
 
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
-        logits_output, next_token_ids = result
-        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-            next_token_ids
-        )
+        logits_output, next_token_ids, bid = result
         self.num_generated_tokens += len(batch.reqs)
 
         # Move logprobs to cpu
-        if
+        if batch.return_logprob:
             next_token_logprobs = logits_output.next_token_logprobs[
-                torch.arange(len(next_token_ids), device=
+                torch.arange(len(next_token_ids), device=self.device),
                next_token_ids,
            ].tolist()
 
-        next_token_ids =
+        next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
+
+        self.token_to_kv_pool.free_group_begin()
 
         # Check finish condition
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
+            if self.server_args.enable_overlap_schedule and req.finished():
+                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
+                continue
+
             req.completion_tokens_wo_jump_forward += 1
             req.output_ids.append(next_token_id)
             req.check_finished()
@@ -782,7 +845,7 @@ class Scheduler:
                 )
 
             if req.finished():
-                self.
+                self.cache_finished_req(req)
 
             if req.return_logprob:
                 req.output_token_logprobs.append(
@@ -791,7 +854,9 @@ class Scheduler:
             if req.top_logprobs_num > 0:
                 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
-        self.
+        self.stream_output(batch.reqs)
+
+        self.token_to_kv_pool.free_group_end()
 
         self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
         if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
@@ -870,7 +935,7 @@ class Scheduler:
 
         return num_input_logprobs
 
-    def
+    def stream_output(self, reqs: List[Req]):
         output_rids = []
         output_meta_info = []
         output_finished_reason: List[BaseFinishReason] = []
@@ -881,22 +946,15 @@ class Scheduler:
             output_read_offsets = []
             output_skip_special_tokens = []
             output_spaces_between_special_tokens = []
+            output_no_stop_trim = []
         else: # embedding or reward model
             output_embeddings = []
-            unfinished_indices = []
 
-
-            if not req.finished() and req is not self.current_inflight_req:
-                unfinished_indices.append(i)
-            else:
-                self.batch_is_full = False
+        is_stream_iter = self.decode_forward_ct % self.stream_interval == 0
 
+        for req in reqs:
             if req.finished() or (
-                req.stream
-                and (
-                    self.decode_forward_ct % self.stream_interval == 0
-                    or len(req.output_ids) == 1
-                )
+                req.stream and (is_stream_iter or len(req.output_ids) == 1)
             ):
                 output_rids.append(req.rid)
                 output_finished_reason.append(req.finished_reason)
@@ -912,11 +970,13 @@ class Scheduler:
                     output_spaces_between_special_tokens.append(
                         req.sampling_params.spaces_between_special_tokens
                     )
+                    output_no_stop_trim.append(req.sampling_params.no_stop_trim)
 
                     meta_info = {
                         "prompt_tokens": len(req.origin_input_ids),
                        "completion_tokens": len(req.output_ids),
                        "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
+                        "cached_tokens": req.cached_tokens,
                        "finish_reason": (
                            req.finished_reason.to_json()
                            if req.finished_reason is not None
@@ -948,7 +1008,7 @@ class Scheduler:
         # Send to detokenizer
         if output_rids:
             if self.is_generation:
-                self.
+                self.send_to_detokenizer.send_pyobj(
                     BatchTokenIDOut(
                         output_rids,
                         output_vids,
@@ -959,10 +1019,11 @@ class Scheduler:
                         output_spaces_between_special_tokens,
                         output_meta_info,
                         output_finished_reason,
+                        output_no_stop_trim,
                     )
                 )
             else: # embedding or reward model
-                self.
+                self.send_to_detokenizer.send_pyobj(
                     BatchEmbeddingOut(
                         output_rids,
                         output_embeddings,
@@ -971,9 +1032,6 @@ class Scheduler:
                     )
                 )
 
-        # Remove finished reqs: update batch tensors
-        batch.filter_batch(unfinished_indices)
-
     def flush_cache(self):
         if len(self.waiting_queue) == 0 and (
             self.running_batch is None or len(self.running_batch.reqs) == 0
@@ -1009,8 +1067,9 @@ class Scheduler:
         # Delete requests in the running batch
         if self.running_batch:
             for req in self.running_batch.reqs:
-                if req.rid == recv_req.rid:
+                if req.rid == recv_req.rid and not req.finished():
                     req.finished_reason = FINISH_ABORT()
+                    self.cache_finished_req(req)
                     break
 
     def update_weights(self, recv_req: UpdateWeightReqInput):
@@ -1055,7 +1114,10 @@ def run_scheduler_process(
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
         pipe_writer.send("ready")
-
+        if server_args.enable_overlap_schedule:
+            scheduler.event_loop_overlap()
+        else:
+            scheduler.event_loop_normal()
     except Exception:
         msg = get_exception_traceback()
         logger.error(msg)
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -150,9 +150,13 @@ class TokenizerManager:
         while self.model_update_lock.locked():
             await asyncio.sleep(0.001)
 
+        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
+            raise ValueError(
+                "This model does not appear to be an embedding model by default. Please add `--is-embedding` when launching the server or try another model."
+            )
+
         obj.post_init()
         is_single = obj.is_single
-
         if is_single:
             async for response in self._handle_single_request(obj, request):
                 yield response