sglang 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their public registry, and is provided for informational purposes only.
Files changed (77)
  1. sglang/bench_latency.py +31 -13
  2. sglang/bench_server_latency.py +21 -10
  3. sglang/bench_serving.py +101 -7
  4. sglang/global_config.py +0 -1
  5. sglang/srt/conversation.py +11 -2
  6. sglang/srt/layers/attention/__init__.py +27 -5
  7. sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
  8. sglang/srt/layers/attention/flashinfer_backend.py +352 -83
  9. sglang/srt/layers/attention/triton_backend.py +6 -4
  10. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
  11. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
  12. sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
  13. sglang/srt/layers/sampler.py +6 -2
  14. sglang/srt/managers/data_parallel_controller.py +177 -0
  15. sglang/srt/managers/detokenizer_manager.py +31 -10
  16. sglang/srt/managers/io_struct.py +11 -2
  17. sglang/srt/managers/schedule_batch.py +126 -43
  18. sglang/srt/managers/schedule_policy.py +2 -1
  19. sglang/srt/managers/scheduler.py +245 -142
  20. sglang/srt/managers/tokenizer_manager.py +14 -1
  21. sglang/srt/managers/tp_worker.py +111 -1
  22. sglang/srt/mem_cache/chunk_cache.py +8 -4
  23. sglang/srt/mem_cache/memory_pool.py +77 -4
  24. sglang/srt/mem_cache/radix_cache.py +15 -7
  25. sglang/srt/model_executor/cuda_graph_runner.py +4 -4
  26. sglang/srt/model_executor/forward_batch_info.py +16 -21
  27. sglang/srt/model_executor/model_runner.py +100 -36
  28. sglang/srt/models/baichuan.py +2 -3
  29. sglang/srt/models/chatglm.py +5 -6
  30. sglang/srt/models/commandr.py +1 -2
  31. sglang/srt/models/dbrx.py +1 -2
  32. sglang/srt/models/deepseek.py +4 -5
  33. sglang/srt/models/deepseek_v2.py +5 -6
  34. sglang/srt/models/exaone.py +1 -2
  35. sglang/srt/models/gemma.py +2 -2
  36. sglang/srt/models/gemma2.py +5 -5
  37. sglang/srt/models/gpt_bigcode.py +5 -5
  38. sglang/srt/models/grok.py +1 -2
  39. sglang/srt/models/internlm2.py +1 -2
  40. sglang/srt/models/llama.py +1 -2
  41. sglang/srt/models/llama_classification.py +1 -2
  42. sglang/srt/models/llama_reward.py +2 -3
  43. sglang/srt/models/llava.py +4 -8
  44. sglang/srt/models/llavavid.py +1 -2
  45. sglang/srt/models/minicpm.py +1 -2
  46. sglang/srt/models/minicpm3.py +5 -6
  47. sglang/srt/models/mixtral.py +1 -2
  48. sglang/srt/models/mixtral_quant.py +1 -2
  49. sglang/srt/models/olmo.py +352 -0
  50. sglang/srt/models/olmoe.py +1 -2
  51. sglang/srt/models/qwen.py +1 -2
  52. sglang/srt/models/qwen2.py +1 -2
  53. sglang/srt/models/qwen2_moe.py +4 -5
  54. sglang/srt/models/stablelm.py +1 -2
  55. sglang/srt/models/torch_native_llama.py +1 -2
  56. sglang/srt/models/xverse.py +1 -2
  57. sglang/srt/models/xverse_moe.py +4 -5
  58. sglang/srt/models/yivl.py +1 -2
  59. sglang/srt/openai_api/adapter.py +97 -52
  60. sglang/srt/openai_api/protocol.py +10 -2
  61. sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
  62. sglang/srt/sampling/sampling_batch_info.py +105 -59
  63. sglang/srt/sampling/sampling_params.py +2 -0
  64. sglang/srt/server.py +171 -37
  65. sglang/srt/server_args.py +127 -48
  66. sglang/srt/utils.py +37 -14
  67. sglang/test/few_shot_gsm8k.py +4 -1
  68. sglang/test/few_shot_gsm8k_engine.py +144 -0
  69. sglang/test/srt/sampling/penaltylib/utils.py +16 -12
  70. sglang/version.py +1 -1
  71. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/METADATA +82 -32
  72. sglang-0.3.4.dist-info/RECORD +143 -0
  73. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
  74. sglang/srt/layers/attention/flashinfer_utils.py +0 -237
  75. sglang-0.3.3.dist-info/RECORD +0 -139
  76. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
  77. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
@@ -17,10 +17,11 @@ limitations under the License.
 
 import json
 import logging
-import multiprocessing
 import os
 import time
 import warnings
+from collections import deque
+from types import SimpleNamespace
 from typing import List, Optional, Union
 
 import torch
@@ -37,6 +38,7 @@ from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchTokenIDOut,
     FlushCacheReq,
+    ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     TokenizedRewardReqInput,
@@ -76,6 +78,9 @@ logger = logging.getLogger(__name__)
 # Crash on warning if we are running CI tests
 crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
 
+# Test retract decode
+test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"
+
 
 class Scheduler:
     """A scheduler that manages a tensor parallel GPU worker."""
@@ -106,7 +111,8 @@ class Scheduler:
             self.send_to_detokenizer = context.socket(zmq.PUSH)
             self.send_to_detokenizer.connect(f"ipc://{port_args.detokenizer_ipc_name}")
         else:
-            self.recv_from_tokenizer = self.send_to_detokenizer = None
+            self.recv_from_tokenizer = None
+            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
 
         # Init tokenizer
         self.model_config = ModelConfig(
@@ -141,9 +147,10 @@ class Scheduler:
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             server_args=server_args,
-            nccl_port=port_args.nccl_ports[0],
+            nccl_port=port_args.nccl_port,
         )
         self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
+        self.device = self.tp_worker.device
 
         # Get token and memory info from the model worker
         (
@@ -189,8 +196,8 @@ class Scheduler:
 
         # Init running status
         self.waiting_queue: List[Req] = []
-        self.running_batch: ScheduleBatch = None
-        self.out_pyobjs = []
+        self.running_batch: Optional[ScheduleBatch] = None
+        self.cur_batch: Optional[ScheduleBatch] = None
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
         self.num_generated_tokens = 0
@@ -229,15 +236,92 @@ class Scheduler:
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
         self.batch_is_full = False
 
+        # Init profiler
+        if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
+            self.profiler = None
+        else:
+            self.torch_profiler_trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
+            logger.info(
+                "Profiling enabled. Traces will be saved to: %s",
+                self.torch_profiler_trace_dir,
+            )
+            self.profiler = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                with_stack=True,
+            )
+
+        # Init states for overlap schedule
+        if self.server_args.enable_overlap_schedule:
+            self.forward_batch_generation = (
+                self.tp_worker.forward_batch_generation_non_blocking
+            )
+            self.resolve_next_token_ids = (
+                lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
+            )
+            self.cache_finished_req = self.tree_cache.cache_finished_req
+        else:
+            self.forward_batch_generation = self.tp_worker.forward_batch_generation
+            self.resolve_next_token_ids = lambda bid, x: x.tolist()
+            self.cache_finished_req = self.tree_cache.cache_finished_req
+
+    @torch.inference_mode()
+    def event_loop_normal(self):
+        self.last_batch = None
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+
+            batch = self.get_next_batch_to_run()
+
+            if batch:
+                result = self.run_batch(batch)
+                self.process_batch_result(batch, result)
+
+                # Decode multiple steps to reduce the overhead
+                if batch.forward_mode.is_decode():
+                    for _ in range(self.server_args.num_continuous_decode_steps - 1):
+                        if not self.running_batch:
+                            break
+                        self.update_running_batch()
+                        if not self.running_batch:
+                            break
+                        result = self.run_batch(batch)
+                        self.process_batch_result(batch, result)
+            else:
+                self.check_memory()
+                self.new_token_ratio = global_config.init_new_token_ratio
+
+            self.last_batch = batch
+
     @torch.inference_mode()
-    def event_loop(self):
+    def event_loop_overlap(self):
+        result_queue = deque()
+
+        self.last_batch = None
+        self.running_batch = None
+
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
 
-            self.run_step()
+            batch = self.get_next_batch_to_run()
+            self.cur_batch = batch
+            if batch:
+                result = self.run_batch(batch)
+                result_queue.append((batch.copy(), result))
 
-            self.send_results()
+            if self.last_batch:
+                tmp_batch, tmp_result = result_queue.popleft()
+                self.process_batch_result(tmp_batch, tmp_result)
+            elif batch is None:
+                self.check_memory()
+                self.new_token_ratio = global_config.init_new_token_ratio
+
+            self.last_batch = batch
 
     def recv_requests(self):
         if self.tp_rank == 0:
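The new `event_loop_overlap` above pipelines CPU-side bookkeeping (result processing, detokenizer output) with the next forward pass: each `(batch, result)` pair is pushed onto a deque and only consumed on the following iteration, once `last_batch` is set. A minimal standalone sketch of the same pattern, using hypothetical `run_batch`/`process_result` stand-ins rather than the sglang APIs:

```python
from collections import deque


def run_batch(batch):
    # Stand-in for launching the forward pass of `batch`.
    return f"result-of-{batch}"


def process_result(batch, result):
    # Stand-in for CPU-side work: finish checks, detokenization, streaming.
    print(f"processed {batch}: {result}")


def event_loop_overlap(batches):
    """Consume the result of iteration i-1 while iteration i is in flight."""
    result_queue = deque()
    last_batch = None

    for batch in batches + [None]:  # trailing None drains the queue
        if batch is not None:
            result_queue.append((batch, run_batch(batch)))

        # Mirror the scheduler's `if self.last_batch:` check: only pop a result
        # if the previous iteration actually queued one.
        if last_batch is not None:
            process_result(*result_queue.popleft())

        last_batch = batch


event_loop_overlap(["batch0", "batch1", "batch2"])
```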
@@ -270,7 +354,14 @@ class Scheduler:
                 self.abort_request(recv_req)
             elif isinstance(recv_req, UpdateWeightReqInput):
                 success, message = self.update_weights(recv_req)
-                self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
+                self.send_to_detokenizer.send_pyobj(
+                    UpdateWeightReqOutput(success, message)
+                )
+            elif isinstance(recv_req, ProfileReq):
+                if recv_req == ProfileReq.START_PROFILE:
+                    self.start_profile()
+                else:
+                    self.stop_profile()
             else:
                 raise ValueError(f"Invalid request: {recv_req}")
 
@@ -363,12 +454,6 @@ class Scheduler:
 
         self.waiting_queue.append(req)
 
-    def send_results(self):
-        if self.tp_rank == 0:
-            for obj in self.out_pyobjs:
-                self.send_to_detokenizer.send_pyobj(obj)
-            self.out_pyobjs = []
-
     def print_decode_stats(self):
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
@@ -376,9 +461,10 @@ class Scheduler:
         throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()
+        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
         logger.info(
             f"Decode batch. "
-            f"#running-req: {len(self.running_batch.reqs)}, "
+            f"#running-req: {num_running_reqs}, "
             f"#token: {num_used}, "
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
             f"gen throughput (token/s): {throughput:.2f}, "
@@ -406,41 +492,45 @@ class Scheduler:
             )
             exit(1) if crash_on_warning else None
 
-    def run_step(self):
+    def get_next_batch_to_run(self):
+        # Merge the prefill batch into the running batch
+        if (
+            self.last_batch
+            and not self.last_batch.forward_mode.is_decode()
+            and not self.last_batch.is_empty()
+        ):
+            if self.current_inflight_req:
+                self.last_batch.filter_batch(
+                    current_inflight_req=self.current_inflight_req
+                )
+                self.tree_cache.cache_unfinished_req(self.current_inflight_req)
+                # Inflight request keeps its rid but will get a new req_pool_idx.
+                self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx)
+                self.batch_is_full = False
+            if not self.last_batch.is_empty():
+                if self.running_batch is None:
+                    self.running_batch = self.last_batch
+                else:
+                    self.running_batch.merge_batch(self.last_batch)
+
+        # Prefill first
         new_batch = self.get_new_batch_prefill()
         if new_batch is not None:
-            # Run a new prefill batch
-            # replace run_batch with the uncommented line to use pytorch profiler
-            # result = pytorch_profile(
-            # "profile_prefill_step", self.run_batch, new_batch, data_size=len(new_batch.reqs)
-            # )
-            result = self.run_batch(new_batch)
-            self.process_batch_result(new_batch, result)
-        else:
-            if self.running_batch is not None:
-                # Run a few decode batches continuously for reducing overhead
-                for _ in range(global_config.num_continue_decode_steps):
-                    batch = self.get_new_batch_decode()
-
-                    if batch:
-                        # replace run_batch with the uncommented line to use pytorch profiler
-                        # result = pytorch_profile(
-                        # "profile_decode_step",
-                        # self.run_batch,
-                        # batch,
-                        # data_size=len(batch.reqs),
-                        # )
-                        result = self.run_batch(batch)
-                        self.process_batch_result(batch, result)
+            return new_batch
 
-                    if self.running_batch is None:
-                        break
+        # Check memory
+        if self.running_batch is None:
+            return
 
-                    if self.out_pyobjs and self.running_batch.has_stream:
-                        break
-            else:
-                self.check_memory()
-                self.new_token_ratio = global_config.init_new_token_ratio
+        # Run decode
+        before_bs = self.running_batch.batch_size()
+        self.update_running_batch()
+        if not self.running_batch:
+            self.batch_is_full = False
+            return None
+        if before_bs != self.running_batch.batch_size():
+            self.batch_is_full = False
+        return self.running_batch
 
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Handle the cases where prefill is not allowed
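`run_step` is replaced by `get_next_batch_to_run`, which only decides what to run: fold the previous prefill batch into the running decode batch, prefer a freshly built prefill batch when one exists, and otherwise advance the decode batch via `update_running_batch`. A condensed sketch of that prefill-first policy, using toy types instead of sglang's `ScheduleBatch`:

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class ToyBatch:
    reqs: List[str] = field(default_factory=list)

    def merge(self, other: "ToyBatch") -> None:
        self.reqs.extend(other.reqs)


def get_next_batch_to_run(
    last_prefill: Optional[ToyBatch],
    running: Optional[ToyBatch],
    new_prefill: Optional[ToyBatch],
) -> Tuple[Optional[ToyBatch], Optional[ToyBatch]]:
    """Prefill-first policy: merge the last prefill into the running (decode)
    batch, run a new prefill if one was formed, otherwise keep decoding."""
    if last_prefill is not None:
        if running is None:
            running = last_prefill
        else:
            running.merge(last_prefill)

    if new_prefill is not None:
        return new_prefill, running  # run prefill this iteration
    return running, running  # otherwise continue decoding (may be None)


# The prefill from the previous step joins the decode batch; the new prefill runs first.
batch, running = get_next_batch_to_run(ToyBatch(["a"]), None, ToyBatch(["b"]))
print(batch.reqs, running.reqs)  # ['b'] ['a']
```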
@@ -449,9 +539,7 @@ class Scheduler:
         ) and self.current_inflight_req is None:
             return None
 
-        running_bs = (
-            len(self.running_batch.reqs) if self.running_batch is not None else 0
-        )
+        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
         if running_bs >= self.max_running_requests:
             self.batch_is_full = True
             return None
@@ -472,7 +560,7 @@ class Scheduler:
         )
 
         has_inflight = self.current_inflight_req is not None
-        if self.current_inflight_req is not None:
+        if has_inflight:
             self.current_inflight_req.init_next_round_input(
                 None if prefix_computed else self.tree_cache
             )
@@ -480,7 +568,7 @@ class Scheduler:
                 self.current_inflight_req
             )
 
-        if self.lora_paths is not None:
+        if self.lora_paths:
             lora_set = (
                 set([req.lora_path for req in self.running_batch.reqs])
                 if self.running_batch is not None
@@ -489,7 +577,7 @@ class Scheduler:
 
         for req in self.waiting_queue:
             if (
-                self.lora_paths is not None
+                self.lora_paths
                 and len(
                     lora_set
                     | set([req.lora_path for req in adder.can_run_list])
@@ -511,16 +599,20 @@ class Scheduler:
                 self.batch_is_full = True
                 break
 
+        # Update waiting queue
         can_run_list = adder.can_run_list
+        if len(can_run_list) == 0:
+            return None
+        self.waiting_queue = [
+            x for x in self.waiting_queue if x not in set(can_run_list)
+        ]
 
         if adder.new_inflight_req is not None:
             assert self.current_inflight_req is None
             self.current_inflight_req = adder.new_inflight_req
 
-        if len(can_run_list) == 0:
-            return None
-
-        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
+        if self.current_inflight_req:
+            self.current_inflight_req.is_inflight_req += 1
 
         # Print stats
         if self.tp_rank == 0:
@@ -573,21 +665,27 @@ class Scheduler:
         new_batch.prepare_for_extend(self.model_config.vocab_size)
 
         # Mixed-style chunked prefill
-        decoding_reqs = []
         if self.is_mixed_chunk and self.running_batch is not None:
             self.running_batch.prepare_for_decode()
             new_batch.mix_with_running(self.running_batch)
-            decoding_reqs = self.running_batch.reqs
+            new_batch.decoding_reqs = self.running_batch.reqs
             self.running_batch = None
-        new_batch.decoding_reqs = decoding_reqs
+        else:
+            new_batch.decoding_reqs = None
 
         return new_batch
 
-    def get_new_batch_decode(self) -> Optional[ScheduleBatch]:
+    def update_running_batch(self):
+        global test_retract
         batch = self.running_batch
 
+        batch.filter_batch()
+        if batch.is_empty():
+            self.running_batch = None
+            return
+
         # Check if decode out of memory
-        if not batch.check_decode_mem():
+        if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
             old_ratio = self.new_token_ratio
 
             retracted_reqs, new_token_ratio = batch.retract_decode()
@@ -610,17 +708,17 @@ class Scheduler:
             jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
             self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():
-                return None
+                self.running_batch = None
+                return
 
         # Update batch tensors
         batch.prepare_for_decode()
-        return batch
 
     def run_batch(self, batch: ScheduleBatch):
         if self.is_generation:
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                 model_worker_batch = batch.get_model_worker_batch()
-                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
+                logits_output, next_token_ids = self.forward_batch_generation(
                     model_worker_batch
                 )
             else:
@@ -631,34 +729,32 @@ class Scheduler:
                     )
                 else:
                     next_token_ids = torch.full((batch.batch_size(),), 0)
-            return logits_output, next_token_ids
+            batch.output_ids = next_token_ids
+            ret = logits_output, next_token_ids, model_worker_batch.bid
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-            return embeddings
+            ret = embeddings, model_worker_batch.bid
+        return ret
 
     def process_batch_result(self, batch: ScheduleBatch, result):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result)
+            if batch.is_empty():
+                self.running_batch = None
         else:
             self.process_batch_result_prefill(batch, result)
 
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
-            logits_output, next_token_ids = result
-            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-                next_token_ids
-            )
-
-            if logits_output:
+            logits_output, next_token_ids, bid = result
+            if batch.return_logprob:
                 # Move logprobs to cpu
                 if logits_output.next_token_logprobs is not None:
                     logits_output.next_token_logprobs = (
                         logits_output.next_token_logprobs[
-                            torch.arange(
-                                len(next_token_ids), device=next_token_ids.device
-                            ),
+                            torch.arange(len(next_token_ids), device=self.device),
                             next_token_ids,
                         ].tolist()
                     )
@@ -669,84 +765,76 @@ class Scheduler:
                     logits_output.normalized_prompt_logprobs.tolist()
                 )
 
-            next_token_ids = next_token_ids.tolist()
+            next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
 
             # Check finish conditions
             logprob_pt = 0
             for i, req in enumerate(batch.reqs):
-                if req is not self.current_inflight_req:
+                if req.is_inflight_req > 0:
+                    req.is_inflight_req -= 1
+                else:
                     # Inflight reqs' prefill is not finished
                     req.completion_tokens_wo_jump_forward += 1
                     req.output_ids.append(next_token_ids[i])
                     req.check_finished()
 
-                    if req.regex_fsm is not None:
-                        req.regex_fsm_state = req.regex_fsm.get_next_state(
-                            req.regex_fsm_state, next_token_ids[i]
-                        )
+                    if req.finished():
+                        self.cache_finished_req(req)
+                    elif not batch.decoding_reqs or req not in batch.decoding_reqs:
+                        self.tree_cache.cache_unfinished_req(req)
 
-                    if req.finished():
-                        self.tree_cache.cache_finished_req(req)
-                    elif req not in batch.decoding_reqs:
-                        # To reduce overhead, only cache prefill reqs
-                        self.tree_cache.cache_unfinished_req(req)
-
-                if req is self.current_inflight_req:
-                    # Inflight request would get a new req idx
-                    self.req_to_token_pool.free(req.req_pool_idx)
+                    if req.regex_fsm is not None:
+                        req.regex_fsm_state = req.regex_fsm.get_next_state(
+                            req.regex_fsm_state, next_token_ids[i]
+                        )
 
-                if req.return_logprob:
-                    logprob_pt += self.add_logprob_return_values(
-                        i, req, logprob_pt, next_token_ids, logits_output
-                    )
+                    if req.return_logprob:
+                        logprob_pt += self.add_logprob_return_values(
+                            i, req, logprob_pt, next_token_ids, logits_output
+                        )
         else:  # embedding or reward model
-            assert batch.extend_num_tokens != 0
-            embeddings = result
+            embeddings, bid = result
+            embeddings = embeddings.tolist()
 
             # Check finish conditions
             for i, req in enumerate(batch.reqs):
                 req.embedding = embeddings[i]
-                if req is not self.current_inflight_req:
+                if req.is_inflight_req > 0:
+                    req.is_inflight_req -= 1
+                else:
                     # Inflight reqs' prefill is not finished
                     # dummy output token for embedding models
                     req.output_ids.append(0)
                     req.check_finished()
 
                     if req.finished():
-                        self.tree_cache.cache_finished_req(req)
+                        self.cache_finished_req(req)
                     else:
                         self.tree_cache.cache_unfinished_req(req)
 
-                if req is self.current_inflight_req:
-                    # Inflight request would get a new req idx
-                    self.req_to_token_pool.free(req.req_pool_idx)
-
-        self.handle_finished_requests(batch)
-
-        if not batch.is_empty():
-            if self.running_batch is None:
-                self.running_batch = batch
-            else:
-                self.running_batch.merge_batch(batch)
+        self.stream_output(batch.reqs)
 
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
-        logits_output, next_token_ids = result
-        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-            next_token_ids
-        )
+        logits_output, next_token_ids, bid = result
         self.num_generated_tokens += len(batch.reqs)
 
         # Move logprobs to cpu
-        if logits_output.next_token_logprobs is not None:
+        if batch.return_logprob:
             next_token_logprobs = logits_output.next_token_logprobs[
-                torch.arange(len(next_token_ids), device=next_token_ids.device),
+                torch.arange(len(next_token_ids), device=self.device),
                 next_token_ids,
             ].tolist()
 
-        next_token_ids = next_token_ids.tolist()
+        next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
+
+        self.token_to_kv_pool.free_group_begin()
 
         # Check finish condition
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
+            if self.server_args.enable_overlap_schedule and req.finished():
+                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
+                continue
+
             req.completion_tokens_wo_jump_forward += 1
             req.output_ids.append(next_token_id)
             req.check_finished()
@@ -757,7 +845,7 @@ class Scheduler:
                 )
 
             if req.finished():
-                self.tree_cache.cache_finished_req(req)
+                self.cache_finished_req(req)
 
             if req.return_logprob:
                 req.output_token_logprobs.append(
@@ -766,15 +854,14 @@ class Scheduler:
             if req.top_logprobs_num > 0:
                 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
-        self.handle_finished_requests(batch)
+        self.stream_output(batch.reqs)
+
+        self.token_to_kv_pool.free_group_end()
 
         self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
         if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
             self.print_decode_stats()
 
-        if self.running_batch.is_empty():
-            self.running_batch = None
-
     def add_logprob_return_values(
         self,
         i: int,
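Decode cleanup is now bracketed by `token_to_kv_pool.free_group_begin()` and `free_group_end()`, so the per-request KV frees issued inside the loop can be deferred and released in a single batched operation. A toy illustration of that grouped-free pattern; the pool class below is illustrative, not sglang's `memory_pool` implementation:

```python
import torch


class ToyTokenPool:
    """Defers frees issued between free_group_begin() and free_group_end()."""

    def __init__(self):
        self.free_slots = torch.empty(0, dtype=torch.int64)
        self.is_in_free_group = False
        self.free_group = []

    def free(self, indices: torch.Tensor) -> None:
        if self.is_in_free_group:
            self.free_group.append(indices)  # record only; release later
        else:
            self.free_slots = torch.cat([self.free_slots, indices])

    def free_group_begin(self) -> None:
        self.is_in_free_group = True
        self.free_group = []

    def free_group_end(self) -> None:
        self.is_in_free_group = False
        if self.free_group:
            # One concatenation instead of one per finished request.
            self.free(torch.cat(self.free_group))


pool = ToyTokenPool()
pool.free_group_begin()
for i in range(4):  # e.g. KV slots of requests that finished in this decode step
    pool.free(torch.tensor([i]))
pool.free_group_end()
print(pool.free_slots)  # tensor([0, 1, 2, 3])
```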
@@ -848,7 +935,7 @@ class Scheduler:
 
         return num_input_logprobs
 
-    def handle_finished_requests(self, batch: ScheduleBatch):
+    def stream_output(self, reqs: List[Req]):
         output_rids = []
         output_meta_info = []
         output_finished_reason: List[BaseFinishReason] = []
@@ -859,22 +946,15 @@ class Scheduler:
             output_read_offsets = []
             output_skip_special_tokens = []
             output_spaces_between_special_tokens = []
+            output_no_stop_trim = []
         else:  # embedding or reward model
             output_embeddings = []
-        unfinished_indices = []
 
-        for i, req in enumerate(batch.reqs):
-            if not req.finished() and req is not self.current_inflight_req:
-                unfinished_indices.append(i)
-            else:
-                self.batch_is_full = False
+        is_stream_iter = self.decode_forward_ct % self.stream_interval == 0
 
+        for req in reqs:
             if req.finished() or (
-                req.stream
-                and (
-                    self.decode_forward_ct % self.stream_interval == 0
-                    or len(req.output_ids) == 1
-                )
+                req.stream and (is_stream_iter or len(req.output_ids) == 1)
             ):
                 output_rids.append(req.rid)
                 output_finished_reason.append(req.finished_reason)
@@ -890,11 +970,13 @@ class Scheduler:
                 output_spaces_between_special_tokens.append(
                     req.sampling_params.spaces_between_special_tokens
                 )
+                output_no_stop_trim.append(req.sampling_params.no_stop_trim)
 
                 meta_info = {
                     "prompt_tokens": len(req.origin_input_ids),
                     "completion_tokens": len(req.output_ids),
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
+                    "cached_tokens": req.cached_tokens,
                     "finish_reason": (
                         req.finished_reason.to_json()
                         if req.finished_reason is not None
@@ -926,7 +1008,7 @@ class Scheduler:
         # Send to detokenizer
         if output_rids:
             if self.is_generation:
-                self.out_pyobjs.append(
+                self.send_to_detokenizer.send_pyobj(
                     BatchTokenIDOut(
                         output_rids,
                         output_vids,
@@ -937,10 +1019,11 @@ class Scheduler:
                         output_spaces_between_special_tokens,
                         output_meta_info,
                         output_finished_reason,
+                        output_no_stop_trim,
                     )
                 )
             else:  # embedding or reward model
-                self.out_pyobjs.append(
+                self.send_to_detokenizer.send_pyobj(
                     BatchEmbeddingOut(
                         output_rids,
                         output_embeddings,
@@ -949,9 +1032,6 @@ class Scheduler:
                     )
                 )
 
-        # Remove finished reqs: update batch tensors
-        batch.filter_batch(unfinished_indices)
-
     def flush_cache(self):
         if len(self.waiting_queue) == 0 and (
             self.running_batch is None or len(self.running_batch.reqs) == 0
@@ -987,8 +1067,9 @@ class Scheduler:
         # Delete requests in the running batch
         if self.running_batch:
             for req in self.running_batch.reqs:
-                if req.rid == recv_req.rid:
+                if req.rid == recv_req.rid and not req.finished():
                     req.finished_reason = FINISH_ABORT()
+                    self.cache_finished_req(req)
                     break
 
     def update_weights(self, recv_req: UpdateWeightReqInput):
@@ -1000,21 +1081,43 @@ class Scheduler:
             logger.error(message)
         return success, message
 
+    def start_profile(self) -> None:
+        if self.profiler is None:
+            raise RuntimeError("Profiler is not enabled.")
+        self.profiler.start()
+
+    def stop_profile(self) -> None:
+        if self.profiler is None:
+            raise RuntimeError("Profiler is not enabled.")
+        self.profiler.stop()
+        self.profiler.export_chrome_trace(
+            self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
+        )
+        logger.info("Profiler is done")
+
 
 def run_scheduler_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     gpu_id: int,
     tp_rank: int,
+    dp_rank: Optional[int],
     pipe_writer,
 ):
-    configure_logger(server_args, prefix=f" TP{tp_rank}")
+    if dp_rank is None:
+        configure_logger(server_args, prefix=f" TP{tp_rank}")
+    else:
+        configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
+
     suppress_other_loggers()
 
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
         pipe_writer.send("ready")
-        scheduler.event_loop()
+        if server_args.enable_overlap_schedule:
+            scheduler.event_loop_overlap()
+        else:
+            scheduler.event_loop_normal()
     except Exception:
         msg = get_exception_traceback()
         logger.error(msg)
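Profiling is now request-driven: when `SGLANG_TORCH_PROFILER_DIR` is set, the scheduler builds a `torch.profiler.profile` object at startup, starts it on `ProfileReq.START_PROFILE`, and on the corresponding stop request exports a Chrome trace into that directory. A self-contained sketch of the same `torch.profiler` usage outside the scheduler; the matmul loop is only a stand-in workload:

```python
import os
import time

import torch

# Assumes the trace directory is supplied the same way, e.g.
#   export SGLANG_TORCH_PROFILER_DIR=/tmp/sglang-traces
trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp/sglang-traces")
os.makedirs(trace_dir, exist_ok=True)

activities = [torch.profiler.ProfilerActivity.CPU]
if torch.cuda.is_available():  # CUDA events are only collected when a GPU is present
    activities.append(torch.profiler.ProfilerActivity.CUDA)

profiler = torch.profiler.profile(activities=activities, with_stack=True)

profiler.start()  # what the scheduler does for ProfileReq.START_PROFILE
x = torch.randn(1024, 1024)
for _ in range(10):
    (x @ x).sum()  # stand-in for forward passes
profiler.stop()  # and for the stop request
profiler.export_chrome_trace(f"{trace_dir}/{time.time()}.trace.json.gz")
```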