sglang 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. sglang/bench_latency.py +31 -13
  2. sglang/bench_server_latency.py +21 -10
  3. sglang/bench_serving.py +101 -7
  4. sglang/global_config.py +0 -1
  5. sglang/srt/conversation.py +11 -2
  6. sglang/srt/layers/attention/__init__.py +27 -5
  7. sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
  8. sglang/srt/layers/attention/flashinfer_backend.py +352 -83
  9. sglang/srt/layers/attention/triton_backend.py +6 -4
  10. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
  11. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
  12. sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
  13. sglang/srt/layers/sampler.py +6 -2
  14. sglang/srt/managers/data_parallel_controller.py +177 -0
  15. sglang/srt/managers/detokenizer_manager.py +31 -10
  16. sglang/srt/managers/io_struct.py +11 -2
  17. sglang/srt/managers/schedule_batch.py +126 -43
  18. sglang/srt/managers/schedule_policy.py +2 -1
  19. sglang/srt/managers/scheduler.py +245 -142
  20. sglang/srt/managers/tokenizer_manager.py +14 -1
  21. sglang/srt/managers/tp_worker.py +111 -1
  22. sglang/srt/mem_cache/chunk_cache.py +8 -4
  23. sglang/srt/mem_cache/memory_pool.py +77 -4
  24. sglang/srt/mem_cache/radix_cache.py +15 -7
  25. sglang/srt/model_executor/cuda_graph_runner.py +4 -4
  26. sglang/srt/model_executor/forward_batch_info.py +16 -21
  27. sglang/srt/model_executor/model_runner.py +100 -36
  28. sglang/srt/models/baichuan.py +2 -3
  29. sglang/srt/models/chatglm.py +5 -6
  30. sglang/srt/models/commandr.py +1 -2
  31. sglang/srt/models/dbrx.py +1 -2
  32. sglang/srt/models/deepseek.py +4 -5
  33. sglang/srt/models/deepseek_v2.py +5 -6
  34. sglang/srt/models/exaone.py +1 -2
  35. sglang/srt/models/gemma.py +2 -2
  36. sglang/srt/models/gemma2.py +5 -5
  37. sglang/srt/models/gpt_bigcode.py +5 -5
  38. sglang/srt/models/grok.py +1 -2
  39. sglang/srt/models/internlm2.py +1 -2
  40. sglang/srt/models/llama.py +1 -2
  41. sglang/srt/models/llama_classification.py +1 -2
  42. sglang/srt/models/llama_reward.py +2 -3
  43. sglang/srt/models/llava.py +4 -8
  44. sglang/srt/models/llavavid.py +1 -2
  45. sglang/srt/models/minicpm.py +1 -2
  46. sglang/srt/models/minicpm3.py +5 -6
  47. sglang/srt/models/mixtral.py +1 -2
  48. sglang/srt/models/mixtral_quant.py +1 -2
  49. sglang/srt/models/olmo.py +352 -0
  50. sglang/srt/models/olmoe.py +1 -2
  51. sglang/srt/models/qwen.py +1 -2
  52. sglang/srt/models/qwen2.py +1 -2
  53. sglang/srt/models/qwen2_moe.py +4 -5
  54. sglang/srt/models/stablelm.py +1 -2
  55. sglang/srt/models/torch_native_llama.py +1 -2
  56. sglang/srt/models/xverse.py +1 -2
  57. sglang/srt/models/xverse_moe.py +4 -5
  58. sglang/srt/models/yivl.py +1 -2
  59. sglang/srt/openai_api/adapter.py +97 -52
  60. sglang/srt/openai_api/protocol.py +10 -2
  61. sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
  62. sglang/srt/sampling/sampling_batch_info.py +105 -59
  63. sglang/srt/sampling/sampling_params.py +2 -0
  64. sglang/srt/server.py +171 -37
  65. sglang/srt/server_args.py +127 -48
  66. sglang/srt/utils.py +37 -14
  67. sglang/test/few_shot_gsm8k.py +4 -1
  68. sglang/test/few_shot_gsm8k_engine.py +144 -0
  69. sglang/test/srt/sampling/penaltylib/utils.py +16 -12
  70. sglang/version.py +1 -1
  71. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/METADATA +82 -32
  72. sglang-0.3.4.dist-info/RECORD +143 -0
  73. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
  74. sglang/srt/layers/attention/flashinfer_utils.py +0 -237
  75. sglang-0.3.3.dist-info/RECORD +0 -139
  76. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
  77. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py
@@ -53,6 +53,7 @@ global_server_args_dict = {
     "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
     "disable_mla": ServerArgs.disable_mla,
     "torchao_config": ServerArgs.torchao_config,
+    "disable_nan_detection": ServerArgs.disable_nan_detection,
 }
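Note: each entry in global_server_args_dict is seeded from the ServerArgs dataclass defaults and updated from the live server arguments at startup; the new key simply extends that mapping. A hedged sketch of the lookup pattern (the consumer code is assumed, it is not part of this diff):

    # Hypothetical consumer: skip the (assumed) NaN check on sampler output
    # when the flag is set. Only the dict lookup itself is taken from this diff.
    from sglang.srt.managers.schedule_batch import global_server_args_dict

    if not global_server_args_dict["disable_nan_detection"]:
        ...  # e.g. validate logits for NaNs before sampling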
 
@@ -196,6 +197,9 @@ class Req:
         # this does not include the jump forward tokens.
         self.completion_tokens_wo_jump_forward = 0
 
+        # The number of cached tokens, that were already cached in the KV store
+        self.cached_tokens = 0
+
         # For vision inputs
         self.image_inputs: Optional[ImageInputs] = None
 
@@ -203,6 +207,7 @@ class Req:
         self.prefix_indices = []
         self.extend_input_len = 0
         self.last_node = None
+        self.is_inflight_req = 0
 
         # Logprobs (arguments)
         self.return_logprob = False
@@ -391,25 +396,30 @@ class Req:
         return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
 
 
+bid = 0
+
+
 @dataclass
 class ScheduleBatch:
     """Store all inforamtion of a batch."""
 
     # Request, memory pool, and cache
     reqs: List[Req]
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: BaseTokenToKVPool
-    tree_cache: BasePrefixCache
+    req_to_token_pool: ReqToTokenPool = None
+    token_to_kv_pool: BaseTokenToKVPool = None
+    tree_cache: BasePrefixCache = None
 
     forward_mode: ForwardMode = None
     sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
-    input_ids: List[int] = None
-    req_pool_indices: List[int] = None
-    seq_lens: List[int] = None
+    input_ids: torch.Tensor = None
+    req_pool_indices: torch.Tensor = None
+    seq_lens: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
 
+    output_ids: torch.Tensor = None
+
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
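Note: with the pool and cache fields now defaulted to None, a ScheduleBatch can be constructed from a subset of its fields, which is what the copy() method added further down relies on. A minimal illustration (the request list reqs is assumed to exist; this snippet is not from the diff):

    # Hypothetical: build a lightweight batch without attaching memory pools.
    mini_batch = ScheduleBatch(reqs=reqs, forward_mode=ForwardMode.DECODE)
    assert mini_batch.req_to_token_pool is None  # pools are attached elsewhere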
@@ -419,10 +429,14 @@ class ScheduleBatch:
     extend_lens: List[int] = None
     extend_num_tokens: int = None
     running_bs: int = None
+    decoding_reqs: List[Req] = None
 
     # Stream
     has_stream: bool = False
 
+    # device
+    device: str = "cuda"
+
     # Has regex
     has_regex: bool = False
 
@@ -439,6 +453,7 @@ class ScheduleBatch:
             tree_cache=tree_cache,
             return_logprob=return_logprob,
             has_stream=has_stream,
+            device=req_to_token_pool.device,
             has_regex=has_regex,
         )
 
@@ -488,17 +503,24 @@ class ScheduleBatch:
 
         pt = 0
         for i, req in enumerate(reqs):
+            already_computed = (
+                req.extend_logprob_start_len + 1 + req.cached_tokens
+                if req.extend_logprob_start_len > 0
+                else 0
+            )
+            req.cached_tokens += len(req.prefix_indices) - already_computed
+
             req.req_pool_idx = req_pool_indices[i]
             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
             seq_lens.append(seq_len)
             assert seq_len - pre_len == req.extend_input_len
 
             if pre_len > 0:
-                self.req_to_token_pool.req_to_token[req.req_pool_idx][
-                    :pre_len
-                ] = req.prefix_indices
+                self.req_to_token_pool.req_to_token[req.req_pool_idx, :pre_len] = (
+                    req.prefix_indices
+                )
 
-            self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
+            self.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
                 out_cache_loc[pt : pt + req.extend_input_len]
             )
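A worked example of the cached-token bookkeeping above, re-running the added arithmetic with made-up numbers: a request whose first 10 prompt tokens hit the prefix cache and that does not request prompt logprobs (extend_logprob_start_len == 0) gets already_computed == 0, so req.cached_tokens grows by the full prefix length.

    # Illustrative values only; this mirrors the two added statements.
    extend_logprob_start_len = 0
    cached_tokens = 0
    prefix_len = 10

    already_computed = (
        extend_logprob_start_len + 1 + cached_tokens
        if extend_logprob_start_len > 0
        else 0
    )                                                # -> 0
    cached_tokens += prefix_len - already_computed   # -> 10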
 
@@ -514,10 +536,15 @@ class ScheduleBatch:
             pt += req.extend_input_len
 
         # Set fields
-        with out_cache_loc.device:
-            self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
-            self.req_pool_indices = torch.tensor(req_pool_indices)
-            self.seq_lens = torch.tensor(seq_lens)
+        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
 
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
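Note: the batch tensors are now built on the host and moved with .to(self.device, non_blocking=True) instead of being created under the out_cache_loc device context. As a general PyTorch point (not specific to this diff), non_blocking=True only overlaps the host-to-device copy with other work when the source tensor lives in pinned memory; otherwise it behaves like a normal transfer. A small sketch:

    import torch

    seq_lens = [5, 17, 3]
    t = torch.tensor(seq_lens, dtype=torch.int32)  # plain pageable host tensor
    if torch.cuda.is_available():
        t = t.to("cuda", non_blocking=True)        # async only if the source is pinned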
@@ -527,7 +554,9 @@ class ScheduleBatch:
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
 
-        self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
+            self, vocab_size, global_server_args_dict["disable_penalizer"]
+        )
 
     def mix_with_running(self, running_batch: "ScheduleBatch"):
         self.forward_mode = ForwardMode.MIXED
@@ -582,9 +611,11 @@ class ScheduleBatch:
 
         retracted_reqs = []
         seq_lens_cpu = self.seq_lens.cpu().numpy()
+        first_iter = True
         while (
            self.token_to_kv_pool.available_size()
            < len(sorted_indices) * global_config.retract_decode_steps
+            or first_iter
         ):
             if len(sorted_indices) == 1:
                 # Corner case: only one request left
@@ -593,6 +624,7 @@ class ScheduleBatch:
                 ), "No space left for only one request"
                 break
 
+            first_iter = False
             idx = sorted_indices.pop()
             req = self.reqs[idx]
             retracted_reqs.append(req)
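The first_iter flag makes the retraction loop run at least once, effectively turning the while into a do-while: one request is retracted per call even when the KV pool momentarily has enough free space. A standalone sketch of the same control-flow pattern (names are illustrative, not from the diff):

    def retract_at_least_one(free_slots: int, slots_per_req: int, candidates: list) -> list:
        retracted = []
        first_iter = True
        while free_slots < len(candidates) * slots_per_req or first_iter:
            if len(candidates) == 1:
                break                      # corner case: keep the last request running
            first_iter = False
            retracted.append(candidates.pop())
            free_slots += slots_per_req    # pretend this request's KV slots were freed
        return retracted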
@@ -633,7 +665,7 @@ class ScheduleBatch:
             req.last_update_decode_tokens = 0
             req.logprob_start_len = 10**9
 
-        self.filter_batch(sorted_indices)
+        self.filter_batch(keep_indices=sorted_indices)
 
         # Reqs in batch are filtered
         total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
@@ -648,7 +680,7 @@ class ScheduleBatch:
 
     def check_for_jump_forward(self, pad_input_ids_func):
         jump_forward_reqs = []
-        filter_indices = [i for i in range(len(self.reqs))]
+        keep_indices = set(i for i in range(len(self.reqs)))
 
         for i, req in enumerate(self.reqs):
             if req.jump_forward_map is not None:
@@ -708,63 +740,71 @@ class ScheduleBatch:
                 )
 
                 jump_forward_reqs.append(req)
-                filter_indices.remove(i)
+                keep_indices.remove(i)
 
-        self.filter_batch(filter_indices)
+        self.filter_batch(keep_indices=list(keep_indices))
 
         return jump_forward_reqs
 
-    def prepare_for_decode(self, input_ids=None):
+    def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
 
-        if input_ids is None:
-            input_ids = [
-                r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
-                for r in self.reqs
-            ]
-
-        self.input_ids = torch.tensor(
-            input_ids, dtype=torch.int32, device=self.seq_lens.device
-        )
-        self.seq_lens.add_(1)
+        self.input_ids = self.output_ids
+        self.output_ids = None
+        if self.sampling_info.penalizer_orchestrator:
+            self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                self.input_ids
+            )
 
         # Alloc mem
         bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)
 
-        self.req_to_token_pool.req_to_token[
-            self.req_pool_indices, self.seq_lens - 1
-        ] = self.out_cache_loc
+        self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
+            self.out_cache_loc
+        )
+        self.seq_lens.add_(1)
 
-    def filter_batch(self, unfinished_indices: List[int]):
-        if unfinished_indices is None or len(unfinished_indices) == 0:
+    def filter_batch(
+        self,
+        current_inflight_req: Optional[Req] = None,
+        keep_indices: Optional[List[int]] = None,
+    ):
+        if keep_indices is None:
+            keep_indices = [
+                i
+                for i in range(len(self.reqs))
+                if not self.reqs[i].finished()
+                and self.reqs[i] is not current_inflight_req
+            ]
+
+        if keep_indices is None or len(keep_indices) == 0:
             # Filter out all requests
             self.reqs = []
             return
 
-        if len(unfinished_indices) == len(self.reqs):
+        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return
 
-        self.reqs = [self.reqs[i] for i in unfinished_indices]
-        new_indices = torch.tensor(
-            unfinished_indices, dtype=torch.int32, device=self.seq_lens.device
+        self.reqs = [self.reqs[i] for i in keep_indices]
+        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
+            self.device, non_blocking=True
         )
         self.req_pool_indices = self.req_pool_indices[new_indices]
         self.seq_lens = self.seq_lens[new_indices]
         self.out_cache_loc = None
+        self.output_ids = self.output_ids[new_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
         if self.return_logprob:
-            self.top_logprobs_nums = [
-                self.top_logprobs_nums[i] for i in unfinished_indices
-            ]
+            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
         else:
             self.top_logprobs_nums = None
 
         self.has_stream = any(req.stream for req in self.reqs)
         self.has_regex = any(req.regex_fsm for req in self.reqs)
 
-        self.sampling_info.filter_batch(unfinished_indices, new_indices)
+        self.sampling_info.filter_batch(keep_indices, new_indices)
 
     def merge_batch(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
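With the new signature, filter_batch() has two calling modes: pass keep_indices explicitly (as retract_decode and check_for_jump_forward now do), or pass nothing and let it drop every finished request plus the current in-flight request (presumably the request still being chunk-prefilled). A hedged usage sketch, assuming batch and inflight_req already exist:

    # Explicit mode: keep only the requests at these positions.
    batch.filter_batch(keep_indices=[0, 2, 3])

    # Implicit mode: drop finished requests, and also exclude the request that is
    # still being prefilled so it is not scheduled for decode yet.
    batch.filter_batch(current_inflight_req=inflight_req)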
@@ -777,6 +817,8 @@ class ScheduleBatch:
         )
         self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
         self.out_cache_loc = None
+        if self.output_ids is not None:
+            self.output_ids = torch.concat([self.output_ids, other.output_ids])
         if self.return_logprob and other.return_logprob:
             self.top_logprobs_nums.extend(other.top_logprobs_nums)
         elif self.return_logprob:
@@ -806,8 +848,14 @@ class ScheduleBatch:
             self.sampling_info.regex_fsm_states = [
                 req.regex_fsm_state for req in self.reqs
             ]
+        else:
+            self.sampling_info.regex_fsms = None
+
+        global bid
+        bid += 1
 
         return ModelWorkerBatch(
+            bid=bid,
             forward_mode=self.forward_mode,
             input_ids=self.input_ids,
             req_pool_indices=self.req_pool_indices,
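Note: bid is the module-level counter added next to the ScheduleBatch definition, so every ModelWorkerBatch gets a monotonically increasing id. The scheduler/tp_worker changes that consume this id are not shown in this section; the counter pattern itself is simply the following (helper name is hypothetical, the diff increments bid inline in the method that builds ModelWorkerBatch):

    bid = 0  # module-level, shared by all batches created in this process

    def next_bid() -> int:
        global bid
        bid += 1
        return bid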
@@ -823,9 +871,26 @@ class ScheduleBatch:
             sampling_info=self.sampling_info,
         )
 
+    def copy(self):
+        return ScheduleBatch(
+            reqs=self.reqs,
+            forward_mode=self.forward_mode,
+            out_cache_loc=self.out_cache_loc,
+            return_logprob=self.return_logprob,
+            decoding_reqs=self.decoding_reqs,
+        )
+
+    def __str__(self):
+        return (
+            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
+            f"#req={(len(self.reqs))})"
+        )
+
 
 @dataclass
 class ModelWorkerBatch:
+    # The batch id
+    bid: int
     # The forward mode
     forward_mode: ForwardMode
     # The input ids
@@ -854,3 +919,21 @@ class ModelWorkerBatch:
 
     # Sampling info
     sampling_info: SamplingBatchInfo
+
+    def copy(self):
+        return ModelWorkerBatch(
+            bid=self.bid,
+            forward_mode=self.forward_mode,
+            input_ids=self.input_ids.clone(),
+            req_pool_indices=self.req_pool_indices,
+            seq_lens=self.seq_lens.clone(),
+            out_cache_loc=self.out_cache_loc,
+            return_logprob=self.return_logprob,
+            top_logprobs_nums=self.top_logprobs_nums,
+            extend_seq_lens=self.extend_seq_lens,
+            extend_prefix_lens=self.extend_prefix_lens,
+            extend_logprob_start_lens=self.extend_logprob_start_lens,
+            image_inputs=self.image_inputs,
+            lora_paths=self.lora_paths,
+            sampling_info=self.sampling_info.copy(),
+        )
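Note which fields this copy() clones: only input_ids and seq_lens get .clone(), presumably because those tensors change between scheduler steps (prepare_for_decode above updates seq_lens in place with seq_lens.add_(1)); the remaining fields are shared by reference. A small illustration of why the clone matters (illustrative only, not from the diff):

    import torch

    seq_lens = torch.tensor([4, 9], dtype=torch.int32)
    snapshot = seq_lens.clone()   # what ModelWorkerBatch.copy() does
    alias = seq_lens              # what a plain attribute copy would do

    seq_lens.add_(1)              # in-place update, as in prepare_for_decode
    print(snapshot.tolist())      # [4, 9]  -> unchanged
    print(alias.tolist())         # [5, 10] -> silently changed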
sglang/srt/managers/schedule_policy.py
@@ -45,12 +45,13 @@ class SchedulePolicy:
     def calc_priority(self, waiting_queue: List[Req]):
         # Compute matched prefix length
         prefix_computed = False
-        if self.policy in ["lpm", "dfs-weight"]:
+        if self.policy == "lpm" or self.policy == "dfs-weight":
             for r in waiting_queue:
                 # NOTE: the prefix_indices must always be aligned with last_node
                 r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
                     rid=r.rid, key=r.adjust_max_prefix_ids()
                 )
+
             prefix_computed = True
 
         if self.policy == "lpm":