sglang 0.3.3.post1__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. sglang/bench_latency.py +28 -10
  2. sglang/bench_server_latency.py +21 -10
  3. sglang/bench_serving.py +101 -7
  4. sglang/global_config.py +0 -1
  5. sglang/srt/layers/attention/__init__.py +27 -5
  6. sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
  7. sglang/srt/layers/attention/flashinfer_backend.py +352 -83
  8. sglang/srt/layers/attention/triton_backend.py +6 -4
  9. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
  10. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
  11. sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
  12. sglang/srt/layers/sampler.py +6 -2
  13. sglang/srt/managers/detokenizer_manager.py +31 -10
  14. sglang/srt/managers/io_struct.py +4 -0
  15. sglang/srt/managers/schedule_batch.py +120 -43
  16. sglang/srt/managers/schedule_policy.py +2 -1
  17. sglang/srt/managers/scheduler.py +202 -140
  18. sglang/srt/managers/tokenizer_manager.py +5 -1
  19. sglang/srt/managers/tp_worker.py +111 -1
  20. sglang/srt/mem_cache/chunk_cache.py +8 -4
  21. sglang/srt/mem_cache/memory_pool.py +77 -4
  22. sglang/srt/mem_cache/radix_cache.py +15 -7
  23. sglang/srt/model_executor/cuda_graph_runner.py +4 -4
  24. sglang/srt/model_executor/forward_batch_info.py +16 -21
  25. sglang/srt/model_executor/model_runner.py +60 -1
  26. sglang/srt/models/baichuan.py +2 -3
  27. sglang/srt/models/chatglm.py +5 -6
  28. sglang/srt/models/commandr.py +1 -2
  29. sglang/srt/models/dbrx.py +1 -2
  30. sglang/srt/models/deepseek.py +4 -5
  31. sglang/srt/models/deepseek_v2.py +5 -6
  32. sglang/srt/models/exaone.py +1 -2
  33. sglang/srt/models/gemma.py +2 -2
  34. sglang/srt/models/gemma2.py +5 -5
  35. sglang/srt/models/gpt_bigcode.py +5 -5
  36. sglang/srt/models/grok.py +1 -2
  37. sglang/srt/models/internlm2.py +1 -2
  38. sglang/srt/models/llama.py +1 -2
  39. sglang/srt/models/llama_classification.py +1 -2
  40. sglang/srt/models/llama_reward.py +2 -3
  41. sglang/srt/models/llava.py +4 -8
  42. sglang/srt/models/llavavid.py +1 -2
  43. sglang/srt/models/minicpm.py +1 -2
  44. sglang/srt/models/minicpm3.py +5 -6
  45. sglang/srt/models/mixtral.py +1 -2
  46. sglang/srt/models/mixtral_quant.py +1 -2
  47. sglang/srt/models/olmo.py +352 -0
  48. sglang/srt/models/olmoe.py +1 -2
  49. sglang/srt/models/qwen.py +1 -2
  50. sglang/srt/models/qwen2.py +1 -2
  51. sglang/srt/models/qwen2_moe.py +4 -5
  52. sglang/srt/models/stablelm.py +1 -2
  53. sglang/srt/models/torch_native_llama.py +1 -2
  54. sglang/srt/models/xverse.py +1 -2
  55. sglang/srt/models/xverse_moe.py +4 -5
  56. sglang/srt/models/yivl.py +1 -2
  57. sglang/srt/openai_api/adapter.py +92 -49
  58. sglang/srt/openai_api/protocol.py +10 -2
  59. sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
  60. sglang/srt/sampling/sampling_batch_info.py +92 -58
  61. sglang/srt/sampling/sampling_params.py +2 -0
  62. sglang/srt/server.py +116 -17
  63. sglang/srt/server_args.py +121 -45
  64. sglang/srt/utils.py +11 -3
  65. sglang/test/few_shot_gsm8k.py +4 -1
  66. sglang/test/few_shot_gsm8k_engine.py +144 -0
  67. sglang/test/srt/sampling/penaltylib/utils.py +16 -12
  68. sglang/version.py +1 -1
  69. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/METADATA +72 -29
  70. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/RECORD +73 -70
  71. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
  72. sglang/srt/layers/attention/flashinfer_utils.py +0 -237
  73. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
  74. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
@@ -17,10 +17,11 @@ limitations under the License.
 
  import json
  import logging
- import multiprocessing
  import os
  import time
  import warnings
+ from collections import deque
+ from types import SimpleNamespace
  from typing import List, Optional, Union
 
  import torch
@@ -77,6 +78,9 @@ logger = logging.getLogger(__name__)
  # Crash on warning if we are running CI tests
  crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
 
+ # Test retract decode
+ test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"
+
 
  class Scheduler:
  """A scheduler that manages a tensor parallel GPU worker."""
@@ -107,7 +111,8 @@ class Scheduler:
  self.send_to_detokenizer = context.socket(zmq.PUSH)
  self.send_to_detokenizer.connect(f"ipc://{port_args.detokenizer_ipc_name}")
  else:
- self.recv_from_tokenizer = self.send_to_detokenizer = None
+ self.recv_from_tokenizer = None
+ self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
 
  # Init tokenizer
  self.model_config = ModelConfig(
@@ -145,6 +150,7 @@ class Scheduler:
  nccl_port=port_args.nccl_port,
  )
  self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
+ self.device = self.tp_worker.device
 
  # Get token and memory info from the model worker
  (
@@ -190,8 +196,8 @@ class Scheduler:
 
  # Init running status
  self.waiting_queue: List[Req] = []
- self.running_batch: ScheduleBatch = None
- self.out_pyobjs = []
+ self.running_batch: Optional[ScheduleBatch] = None
+ self.cur_batch: Optional[ScheduleBatch] = None
  self.decode_forward_ct = 0
  self.stream_interval = server_args.stream_interval
  self.num_generated_tokens = 0
@@ -230,6 +236,7 @@ class Scheduler:
  self.new_token_ratio_decay = global_config.new_token_ratio_decay
  self.batch_is_full = False
 
+ # Init profiler
  if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
  self.profiler = None
  else:
@@ -246,15 +253,75 @@ class Scheduler:
  with_stack=True,
  )
 
+ # Init states for overlap schedule
+ if self.server_args.enable_overlap_schedule:
+ self.forward_batch_generation = (
+ self.tp_worker.forward_batch_generation_non_blocking
+ )
+ self.resolve_next_token_ids = (
+ lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
+ )
+ self.cache_finished_req = self.tree_cache.cache_finished_req
+ else:
+ self.forward_batch_generation = self.tp_worker.forward_batch_generation
+ self.resolve_next_token_ids = lambda bid, x: x.tolist()
+ self.cache_finished_req = self.tree_cache.cache_finished_req
+
  @torch.inference_mode()
- def event_loop(self):
+ def event_loop_normal(self):
+ self.last_batch = None
+
  while True:
  recv_reqs = self.recv_requests()
  self.process_input_requests(recv_reqs)
 
- self.run_step()
+ batch = self.get_next_batch_to_run()
+
+ if batch:
+ result = self.run_batch(batch)
+ self.process_batch_result(batch, result)
+
+ # Decode multiple steps to reduce the overhead
+ if batch.forward_mode.is_decode():
+ for _ in range(self.server_args.num_continuous_decode_steps - 1):
+ if not self.running_batch:
+ break
+ self.update_running_batch()
+ if not self.running_batch:
+ break
+ result = self.run_batch(batch)
+ self.process_batch_result(batch, result)
+ else:
+ self.check_memory()
+ self.new_token_ratio = global_config.init_new_token_ratio
+
+ self.last_batch = batch
+
+ @torch.inference_mode()
+ def event_loop_overlap(self):
+ result_queue = deque()
+
+ self.last_batch = None
+ self.running_batch = None
 
- self.send_results()
+ while True:
+ recv_reqs = self.recv_requests()
+ self.process_input_requests(recv_reqs)
+
+ batch = self.get_next_batch_to_run()
+ self.cur_batch = batch
+ if batch:
+ result = self.run_batch(batch)
+ result_queue.append((batch.copy(), result))
+
+ if self.last_batch:
+ tmp_batch, tmp_result = result_queue.popleft()
+ self.process_batch_result(tmp_batch, tmp_result)
+ elif batch is None:
+ self.check_memory()
+ self.new_token_ratio = global_config.init_new_token_ratio
+
+ self.last_batch = batch
 
  def recv_requests(self):
  if self.tp_rank == 0:
@@ -287,7 +354,9 @@ class Scheduler:
  self.abort_request(recv_req)
  elif isinstance(recv_req, UpdateWeightReqInput):
  success, message = self.update_weights(recv_req)
- self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
+ self.send_to_detokenizer.send_pyobj(
+ UpdateWeightReqOutput(success, message)
+ )
  elif isinstance(recv_req, ProfileReq):
  if recv_req == ProfileReq.START_PROFILE:
  self.start_profile()
@@ -385,12 +454,6 @@ class Scheduler:
 
  self.waiting_queue.append(req)
 
- def send_results(self):
- if self.tp_rank == 0:
- for obj in self.out_pyobjs:
- self.send_to_detokenizer.send_pyobj(obj)
- self.out_pyobjs = []
-
  def print_decode_stats(self):
  num_used = self.max_total_num_tokens - (
  self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
@@ -398,9 +461,10 @@ class Scheduler:
  throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
  self.num_generated_tokens = 0
  self.last_stats_tic = time.time()
+ num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
  logger.info(
  f"Decode batch. "
- f"#running-req: {len(self.running_batch.reqs)}, "
+ f"#running-req: {num_running_reqs}, "
  f"#token: {num_used}, "
  f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
  f"gen throughput (token/s): {throughput:.2f}, "
@@ -428,44 +492,45 @@ class Scheduler:
  )
  exit(1) if crash_on_warning else None
 
- def run_step(self):
+ def get_next_batch_to_run(self):
+ # Merge the prefill batch into the running batch
+ if (
+ self.last_batch
+ and not self.last_batch.forward_mode.is_decode()
+ and not self.last_batch.is_empty()
+ ):
+ if self.current_inflight_req:
+ self.last_batch.filter_batch(
+ current_inflight_req=self.current_inflight_req
+ )
+ self.tree_cache.cache_unfinished_req(self.current_inflight_req)
+ # Inflight request keeps its rid but will get a new req_pool_idx.
+ self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx)
+ self.batch_is_full = False
+ if not self.last_batch.is_empty():
+ if self.running_batch is None:
+ self.running_batch = self.last_batch
+ else:
+ self.running_batch.merge_batch(self.last_batch)
+
+ # Prefill first
  new_batch = self.get_new_batch_prefill()
  if new_batch is not None:
- # Run a new prefill batch
- # replace run_batch with the uncommented line to use pytorch profiler
- # result = pytorch_profile(
- # "profile_prefill_step", self.run_batch, new_batch, data_size=len(new_batch.reqs)
- # )
- result = self.run_batch(new_batch)
- self.process_batch_result(new_batch, result)
- else:
- if self.running_batch is not None:
- # Run a few decode batches continuously for reducing overhead
- for _ in range(global_config.num_continue_decode_steps):
- batch = self.get_new_batch_decode()
-
- if batch:
- # replace run_batch with the uncommented line to use pytorch profiler
- # result = pytorch_profile(
- # "profile_decode_step",
- # self.run_batch,
- # batch,
- # data_size=len(batch.reqs),
- # )
- result = self.run_batch(batch)
- self.process_batch_result(batch, result)
+ return new_batch
 
- if self.running_batch.is_empty():
- self.running_batch = None
+ # Check memory
+ if self.running_batch is None:
+ return
 
- if self.running_batch is None:
- break
-
- if self.out_pyobjs and self.running_batch.has_stream:
- break
- else:
- self.check_memory()
- self.new_token_ratio = global_config.init_new_token_ratio
+ # Run decode
+ before_bs = self.running_batch.batch_size()
+ self.update_running_batch()
+ if not self.running_batch:
+ self.batch_is_full = False
+ return None
+ if before_bs != self.running_batch.batch_size():
+ self.batch_is_full = False
+ return self.running_batch
 
  def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
  # Handle the cases where prefill is not allowed
@@ -474,9 +539,7 @@ class Scheduler:
  ) and self.current_inflight_req is None:
  return None
 
- running_bs = (
- len(self.running_batch.reqs) if self.running_batch is not None else 0
- )
+ running_bs = len(self.running_batch.reqs) if self.running_batch else 0
  if running_bs >= self.max_running_requests:
  self.batch_is_full = True
  return None
@@ -497,7 +560,7 @@ class Scheduler:
  )
 
  has_inflight = self.current_inflight_req is not None
- if self.current_inflight_req is not None:
+ if has_inflight:
  self.current_inflight_req.init_next_round_input(
  None if prefix_computed else self.tree_cache
  )
@@ -505,7 +568,7 @@ class Scheduler:
  self.current_inflight_req
  )
 
- if self.lora_paths is not None:
+ if self.lora_paths:
  lora_set = (
  set([req.lora_path for req in self.running_batch.reqs])
  if self.running_batch is not None
@@ -514,7 +577,7 @@ class Scheduler:
  )
 
  for req in self.waiting_queue:
  if (
- self.lora_paths is not None
+ self.lora_paths
  and len(
  lora_set
  | set([req.lora_path for req in adder.can_run_list])
@@ -536,16 +599,20 @@ class Scheduler:
  self.batch_is_full = True
  break
 
+ # Update waiting queue
  can_run_list = adder.can_run_list
+ if len(can_run_list) == 0:
+ return None
+ self.waiting_queue = [
+ x for x in self.waiting_queue if x not in set(can_run_list)
+ ]
 
  if adder.new_inflight_req is not None:
  assert self.current_inflight_req is None
  self.current_inflight_req = adder.new_inflight_req
 
- if len(can_run_list) == 0:
- return None
-
- self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
+ if self.current_inflight_req:
+ self.current_inflight_req.is_inflight_req += 1
 
  # Print stats
  if self.tp_rank == 0:
@@ -598,21 +665,27 @@ class Scheduler:
  new_batch.prepare_for_extend(self.model_config.vocab_size)
 
  # Mixed-style chunked prefill
- decoding_reqs = []
  if self.is_mixed_chunk and self.running_batch is not None:
  self.running_batch.prepare_for_decode()
  new_batch.mix_with_running(self.running_batch)
- decoding_reqs = self.running_batch.reqs
+ new_batch.decoding_reqs = self.running_batch.reqs
  self.running_batch = None
- new_batch.decoding_reqs = decoding_reqs
+ else:
+ new_batch.decoding_reqs = None
 
  return new_batch
 
- def get_new_batch_decode(self) -> Optional[ScheduleBatch]:
+ def update_running_batch(self):
+ global test_retract
  batch = self.running_batch
 
+ batch.filter_batch()
+ if batch.is_empty():
+ self.running_batch = None
+ return
+
  # Check if decode out of memory
- if not batch.check_decode_mem():
+ if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
  old_ratio = self.new_token_ratio
 
  retracted_reqs, new_token_ratio = batch.retract_decode()
@@ -635,17 +708,17 @@ class Scheduler:
  jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
  self.waiting_queue.extend(jump_forward_reqs)
  if batch.is_empty():
- return None
+ self.running_batch = None
+ return
 
  # Update batch tensors
  batch.prepare_for_decode()
- return batch
 
  def run_batch(self, batch: ScheduleBatch):
  if self.is_generation:
  if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
  model_worker_batch = batch.get_model_worker_batch()
- logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
+ logits_output, next_token_ids = self.forward_batch_generation(
  model_worker_batch
  )
  else:
@@ -656,34 +729,32 @@ class Scheduler:
  )
  else:
  next_token_ids = torch.full((batch.batch_size(),), 0)
- return logits_output, next_token_ids
+ batch.output_ids = next_token_ids
+ ret = logits_output, next_token_ids, model_worker_batch.bid
  else: # embedding or reward model
  assert batch.extend_num_tokens != 0
  model_worker_batch = batch.get_model_worker_batch()
  embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
- return embeddings
+ ret = embeddings, model_worker_batch.bid
+ return ret
 
  def process_batch_result(self, batch: ScheduleBatch, result):
  if batch.forward_mode.is_decode():
  self.process_batch_result_decode(batch, result)
+ if batch.is_empty():
+ self.running_batch = None
  else:
  self.process_batch_result_prefill(batch, result)
 
  def process_batch_result_prefill(self, batch: ScheduleBatch, result):
  if self.is_generation:
- logits_output, next_token_ids = result
- batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
- next_token_ids
- )
-
- if logits_output:
+ logits_output, next_token_ids, bid = result
+ if batch.return_logprob:
  # Move logprobs to cpu
  if logits_output.next_token_logprobs is not None:
  logits_output.next_token_logprobs = (
  logits_output.next_token_logprobs[
- torch.arange(
- len(next_token_ids), device=next_token_ids.device
- ),
+ torch.arange(len(next_token_ids), device=self.device),
  next_token_ids,
  ].tolist()
  )
@@ -694,84 +765,76 @@ class Scheduler:
  logits_output.normalized_prompt_logprobs.tolist()
  )
 
- next_token_ids = next_token_ids.tolist()
+ next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
 
  # Check finish conditions
  logprob_pt = 0
  for i, req in enumerate(batch.reqs):
- if req is not self.current_inflight_req:
+ if req.is_inflight_req > 0:
+ req.is_inflight_req -= 1
+ else:
  # Inflight reqs' prefill is not finished
  req.completion_tokens_wo_jump_forward += 1
  req.output_ids.append(next_token_ids[i])
  req.check_finished()
 
- if req.regex_fsm is not None:
- req.regex_fsm_state = req.regex_fsm.get_next_state(
- req.regex_fsm_state, next_token_ids[i]
- )
-
- if req.finished():
- self.tree_cache.cache_finished_req(req)
- elif req not in batch.decoding_reqs:
- # To reduce overhead, only cache prefill reqs
- self.tree_cache.cache_unfinished_req(req)
+ if req.finished():
+ self.cache_finished_req(req)
+ elif not batch.decoding_reqs or req not in batch.decoding_reqs:
+ self.tree_cache.cache_unfinished_req(req)
 
- if req is self.current_inflight_req:
- # Inflight request would get a new req idx
- self.req_to_token_pool.free(req.req_pool_idx)
+ if req.regex_fsm is not None:
+ req.regex_fsm_state = req.regex_fsm.get_next_state(
+ req.regex_fsm_state, next_token_ids[i]
+ )
 
- if req.return_logprob:
- logprob_pt += self.add_logprob_return_values(
- i, req, logprob_pt, next_token_ids, logits_output
- )
+ if req.return_logprob:
+ logprob_pt += self.add_logprob_return_values(
+ i, req, logprob_pt, next_token_ids, logits_output
+ )
  else: # embedding or reward model
- assert batch.extend_num_tokens != 0
- embeddings = result
+ embeddings, bid = result
+ embeddings = embeddings.tolist()
 
  # Check finish conditions
  for i, req in enumerate(batch.reqs):
  req.embedding = embeddings[i]
- if req is not self.current_inflight_req:
+ if req.is_inflight_req > 0:
+ req.is_inflight_req -= 1
+ else:
  # Inflight reqs' prefill is not finished
  # dummy output token for embedding models
  req.output_ids.append(0)
  req.check_finished()
 
  if req.finished():
- self.tree_cache.cache_finished_req(req)
+ self.cache_finished_req(req)
  else:
  self.tree_cache.cache_unfinished_req(req)
 
- if req is self.current_inflight_req:
- # Inflight request would get a new req idx
- self.req_to_token_pool.free(req.req_pool_idx)
-
- self.handle_finished_requests(batch)
-
- if not batch.is_empty():
- if self.running_batch is None:
- self.running_batch = batch
- else:
- self.running_batch.merge_batch(batch)
+ self.stream_output(batch.reqs)
 
  def process_batch_result_decode(self, batch: ScheduleBatch, result):
- logits_output, next_token_ids = result
- batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
- next_token_ids
- )
+ logits_output, next_token_ids, bid = result
  self.num_generated_tokens += len(batch.reqs)
 
  # Move logprobs to cpu
- if logits_output.next_token_logprobs is not None:
+ if batch.return_logprob:
  next_token_logprobs = logits_output.next_token_logprobs[
- torch.arange(len(next_token_ids), device=next_token_ids.device),
+ torch.arange(len(next_token_ids), device=self.device),
  next_token_ids,
  ].tolist()
 
- next_token_ids = next_token_ids.tolist()
+ next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
+
+ self.token_to_kv_pool.free_group_begin()
 
  # Check finish condition
  for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
+ if self.server_args.enable_overlap_schedule and req.finished():
+ self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
+ continue
+
  req.completion_tokens_wo_jump_forward += 1
  req.output_ids.append(next_token_id)
  req.check_finished()
@@ -782,7 +845,7 @@ class Scheduler:
  )
 
  if req.finished():
- self.tree_cache.cache_finished_req(req)
+ self.cache_finished_req(req)
 
  if req.return_logprob:
  req.output_token_logprobs.append(
@@ -791,7 +854,9 @@ class Scheduler:
  )
  if req.top_logprobs_num > 0:
  req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
- self.handle_finished_requests(batch)
+ self.stream_output(batch.reqs)
+
+ self.token_to_kv_pool.free_group_end()
 
  self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
  if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
@@ -870,7 +935,7 @@ class Scheduler:
 
  return num_input_logprobs
 
- def handle_finished_requests(self, batch: ScheduleBatch):
+ def stream_output(self, reqs: List[Req]):
  output_rids = []
  output_meta_info = []
  output_finished_reason: List[BaseFinishReason] = []
@@ -881,22 +946,15 @@ class Scheduler:
  output_read_offsets = []
  output_skip_special_tokens = []
  output_spaces_between_special_tokens = []
+ output_no_stop_trim = []
  else: # embedding or reward model
  output_embeddings = []
- unfinished_indices = []
 
- for i, req in enumerate(batch.reqs):
- if not req.finished() and req is not self.current_inflight_req:
- unfinished_indices.append(i)
- else:
- self.batch_is_full = False
+ is_stream_iter = self.decode_forward_ct % self.stream_interval == 0
 
+ for req in reqs:
  if req.finished() or (
- req.stream
- and (
- self.decode_forward_ct % self.stream_interval == 0
- or len(req.output_ids) == 1
- )
+ req.stream and (is_stream_iter or len(req.output_ids) == 1)
  ):
  output_rids.append(req.rid)
  output_finished_reason.append(req.finished_reason)
@@ -912,11 +970,13 @@ class Scheduler:
  output_spaces_between_special_tokens.append(
  req.sampling_params.spaces_between_special_tokens
  )
+ output_no_stop_trim.append(req.sampling_params.no_stop_trim)
 
  meta_info = {
  "prompt_tokens": len(req.origin_input_ids),
  "completion_tokens": len(req.output_ids),
  "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
+ "cached_tokens": req.cached_tokens,
  "finish_reason": (
  req.finished_reason.to_json()
  if req.finished_reason is not None
@@ -948,7 +1008,7 @@ class Scheduler:
  # Send to detokenizer
  if output_rids:
  if self.is_generation:
- self.out_pyobjs.append(
+ self.send_to_detokenizer.send_pyobj(
  BatchTokenIDOut(
  output_rids,
  output_vids,
@@ -959,10 +1019,11 @@ class Scheduler:
  output_spaces_between_special_tokens,
  output_meta_info,
  output_finished_reason,
+ output_no_stop_trim,
  )
  )
  else: # embedding or reward model
- self.out_pyobjs.append(
+ self.send_to_detokenizer.send_pyobj(
  BatchEmbeddingOut(
  output_rids,
  output_embeddings,
@@ -971,9 +1032,6 @@ class Scheduler:
  )
  )
 
- # Remove finished reqs: update batch tensors
- batch.filter_batch(unfinished_indices)
-
  def flush_cache(self):
  if len(self.waiting_queue) == 0 and (
  self.running_batch is None or len(self.running_batch.reqs) == 0
@@ -1009,8 +1067,9 @@ class Scheduler:
  # Delete requests in the running batch
  if self.running_batch:
  for req in self.running_batch.reqs:
- if req.rid == recv_req.rid:
+ if req.rid == recv_req.rid and not req.finished():
  req.finished_reason = FINISH_ABORT()
+ self.cache_finished_req(req)
  break
 
  def update_weights(self, recv_req: UpdateWeightReqInput):
@@ -1055,7 +1114,10 @@ def run_scheduler_process(
  try:
  scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
  pipe_writer.send("ready")
- scheduler.event_loop()
+ if server_args.enable_overlap_schedule:
+ scheduler.event_loop_overlap()
+ else:
+ scheduler.event_loop_normal()
  except Exception:
  msg = get_exception_traceback()
  logger.error(msg)
sglang/srt/managers/tokenizer_manager.py
@@ -150,9 +150,13 @@ class TokenizerManager:
  while self.model_update_lock.locked():
  await asyncio.sleep(0.001)
 
+ if isinstance(obj, EmbeddingReqInput) and self.is_generation:
+ raise ValueError(
+ "This model does not appear to be an embedding model by default. Please add `--is-embedding` when launching the server or try another model."
+ )
+
  obj.post_init()
  is_single = obj.is_single
-
  if is_single:
  async for response in self._handle_single_request(obj, request):
  yield response