sglang 0.1.19__py3-none-any.whl → 0.1.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,10 +15,16 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
+# Store some global server args
+global_server_args_dict = {}
+
 
 class ForwardMode(IntEnum):
+    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
     PREFILL = auto()
+    # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
     EXTEND = auto()
+    # Decode one token.
     DECODE = auto()
 
 
@@ -66,7 +72,10 @@ class FINISH_ABORT(BaseFinishReason):
 
 
 class Req:
+    """Store all inforamtion of a request."""
+
     def __init__(self, rid, origin_input_text, origin_input_ids):
+        # Input and output info
         self.rid = rid
         self.origin_input_text = origin_input_text
         self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
@@ -74,7 +83,7 @@ class Req:
         self.output_ids = []  # Each decode stage's output ids
         self.input_ids = None  # input_ids = origin_input_ids + output_ids
 
-        # For incremental decode
+        # For incremental decoding
         self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
@@ -89,20 +98,19 @@ class Req:
         self.image_offset = 0
         self.pad_value = None
 
+        # Prefix info
+        self.extend_input_len = 0
+        self.prefix_indices = []
+        self.last_node = None
+
         # Sampling parameters
         self.sampling_params = None
         self.stream = False
 
-        self.tokenizer = None
-
         # Check finish
+        self.tokenizer = None
         self.finished_reason = None
 
-        # Prefix info
-        self.extend_input_len = 0
-        self.prefix_indices = []
-        self.last_node = None
-
         # Logprobs
         self.return_logprob = False
         self.logprob_start_len = 0
@@ -252,35 +260,36 @@ class Req:
 
 @dataclass
 class Batch:
+    """Store all inforamtion of a batch."""
+
+    # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: TokenToKVPool
     tree_cache: RadixCache
 
-    # batched arguments to model runner
+    # Batched arguments to model runner
     input_ids: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
     prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: torch.Tensor = None
-    out_cache_cont_end: torch.Tensor = None
 
-    # for processing logprobs
+    # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
-    # for multimodal
+    # For multimodal
     pixel_values: List[torch.Tensor] = None
     image_sizes: List[List[int]] = None
     image_offsets: List[int] = None
 
-    # other arguments for control
+    # Other arguments for control
     output_ids: torch.Tensor = None
     extend_num_tokens: int = None
 
-    # batched sampling params
+    # Batched sampling params
     temperatures: torch.Tensor = None
     top_ps: torch.Tensor = None
     top_ks: torch.Tensor = None
@@ -303,8 +312,8 @@ class Batch:
     def is_empty(self):
         return len(self.reqs) == 0
 
-    # whether batch has at least 1 streaming request
     def has_stream(self) -> bool:
+        # Return whether batch has at least 1 streaming request
         return any(r.stream for r in self.reqs)
 
     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
@@ -338,7 +347,7 @@ class Batch:
 
         position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
 
-        # Alloc mem
+        # Allocate memory
         seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
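
For reference, the "Allocate memory" block touched above sizes the new KV slots as the total sequence length minus the already-cached prefix length. A minimal standalone sketch of that arithmetic (illustrative only, not part of the diff; the values are made up):

    import numpy as np

    # Hypothetical batch: two requests that share a cached prefix (e.g., a system prompt).
    seq_lens = np.array([12, 20])    # full lengths: prefix + new tokens
    prefix_lens = np.array([4, 4])   # tokens whose KV cache already exists

    # Mirrors the block above: only non-prefix tokens need newly allocated KV slots.
    extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
    print(extend_num_tokens)  # 24
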
@@ -555,21 +564,12 @@ class Batch:
 
         # Alloc mem
         bs = len(self.reqs)
-        alloc_res = self.token_to_kv_pool.alloc_contiguous(bs)
-        if alloc_res is None:
-            self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
+        self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
 
-            if self.out_cache_loc is None:
-                print("Decode out of memory. This should never happen.")
-                self.tree_cache.pretty_print()
-                exit()
-
-            self.out_cache_cont_start = None
-            self.out_cache_cont_end = None
-        else:
-            self.out_cache_loc = alloc_res[0]
-            self.out_cache_cont_start = alloc_res[1]
-            self.out_cache_cont_end = alloc_res[2]
+        if self.out_cache_loc is None:
+            print("Decode out of memory. This should never happen.")
+            self.tree_cache.pretty_print()
+            exit()
 
         self.req_to_token_pool.req_to_token[
             self.req_pool_indices, self.seq_lens - 1
@@ -583,7 +583,7 @@ class Batch:
         self.req_pool_indices = self.req_pool_indices[new_indices]
         self.prefix_lens = None
         self.position_ids_offsets = self.position_ids_offsets[new_indices]
-        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.out_cache_loc = None
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
@@ -611,7 +611,7 @@ class Batch:
         self.position_ids_offsets = torch.concat(
             [self.position_ids_offsets, other.position_ids_offsets]
         )
-        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.out_cache_loc = None
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
@@ -664,7 +664,11 @@ class Batch:
         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)
         probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
-        sampled_index = torch.multinomial(probs_sort, num_samples=1)
+        try:
+            sampled_index = torch.multinomial(probs_sort, num_samples=1)
+        except RuntimeError as e:
+            warnings.warn(f"Ignore errors in sampling: {e}")
+            sampled_index = torch.ones(probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device)
         batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
             -1
         )
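
The new sampling path above wraps torch.multinomial in a try/except so that a rare numerical failure (e.g., a row of probs_sort containing NaN or inf after top-p/top-k filtering) no longer crashes the batch; it falls back to a fixed sampled index instead. A self-contained sketch of the same pattern, using a made-up probability tensor (illustrative only, not part of the diff):

    import warnings
    import torch

    probs_sort = torch.tensor([[0.7, 0.2, 0.1],
                               [float("nan"), 0.0, 0.0]])  # second row makes multinomial raise

    try:
        sampled_index = torch.multinomial(probs_sort, num_samples=1)
    except RuntimeError as e:
        warnings.warn(f"Ignore errors in sampling: {e}")
        # Fall back to index 1 for every row, matching the diff's torch.ones(...) fallback.
        sampled_index = torch.ones(probs_sort.shape[:-1] + (1,), dtype=torch.int64)

Note that, as in the diff, the fallback replaces the sampled index for the whole batch, not just the offending row.
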
@@ -692,3 +696,207 @@ def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor
     ] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     return probs_sort, probs_idx
+
+
+@dataclass
+class InputMetadata:
+    """Store all inforamtion of a forward pass."""
+
+    forward_mode: ForwardMode
+    batch_size: int
+    total_num_tokens: int
+    req_pool_indices: torch.Tensor
+    seq_lens: torch.Tensor
+    positions: torch.Tensor
+    req_to_token_pool: ReqToTokenPool
+    token_to_kv_pool: TokenToKVPool
+
+    # For extend
+    extend_seq_lens: torch.Tensor
+    extend_start_loc: torch.Tensor
+    extend_no_prefix: bool
+
+    # Output location of the KV cache
+    out_cache_loc: torch.Tensor = None
+
+    # Output options
+    return_logprob: bool = False
+    top_logprobs_nums: List[int] = None
+
+    # Trition attention backend
+    triton_max_seq_len: int = 0
+    triton_max_extend_len: int = 0
+    triton_start_loc: torch.Tensor = None
+    triton_prefix_lens: torch.Tensor = None
+
+    # FlashInfer attention backend
+    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
+    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
+    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
+
+    @classmethod
+    def create(
+        cls,
+        model_runner,
+        forward_mode,
+        req_pool_indices,
+        seq_lens,
+        prefix_lens,
+        position_ids_offsets,
+        out_cache_loc,
+        top_logprobs_nums=None,
+        return_logprob=False,
+        skip_flashinfer_init=False,
+    ):
+        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
+            init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
+                                 model_runner.flashinfer_decode_wrapper)
+
+        batch_size = len(req_pool_indices)
+
+        if forward_mode == ForwardMode.DECODE:
+            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
+            extend_seq_lens = extend_start_loc = extend_no_prefix = None
+            if not model_runner.server_args.disable_flashinfer:
+                # This variable is not needed in this case,
+                # we do not compute it to make it compatbile with cuda graph.
+                total_num_tokens = None
+            else:
+                total_num_tokens = int(torch.sum(seq_lens))
+        else:
+            seq_lens_cpu = seq_lens.cpu().numpy()
+            prefix_lens_cpu = prefix_lens.cpu().numpy()
+            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
+            positions = torch.tensor(
+                np.concatenate(
+                    [
+                        np.arange(
+                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
+                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
+                        )
+                        for i in range(batch_size)
+                    ],
+                    axis=0,
+                ),
+                device="cuda",
+            )
+            extend_seq_lens = seq_lens - prefix_lens
+            extend_start_loc = torch.zeros_like(seq_lens)
+            extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
+            extend_no_prefix = torch.all(prefix_lens == 0)
+            total_num_tokens = int(torch.sum(seq_lens))
+
+        ret = cls(
+            forward_mode=forward_mode,
+            batch_size=batch_size,
+            total_num_tokens=total_num_tokens,
+            req_pool_indices=req_pool_indices,
+            seq_lens=seq_lens,
+            positions=positions,
+            req_to_token_pool=model_runner.req_to_token_pool,
+            token_to_kv_pool=model_runner.token_to_kv_pool,
+            out_cache_loc=out_cache_loc,
+            extend_seq_lens=extend_seq_lens,
+            extend_start_loc=extend_start_loc,
+            extend_no_prefix=extend_no_prefix,
+            return_logprob=return_logprob,
+            top_logprobs_nums=top_logprobs_nums,
+            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
+        )
+
+        if model_runner.server_args.disable_flashinfer:
+            (ret.triton_max_seq_len,
+             ret.triton_max_extend_len,
+             ret.triton_start_loc,
+             ret.triton_prefix_lens) = init_triton_args(forward_mode, seq_lens, prefix_lens)
+
+        return ret
+
+
+def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
+                         flashinfer_decode_wrapper):
+    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
+    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
+    head_dim = model_runner.model_config.head_dim
+    batch_size = len(req_pool_indices)
+
+    if forward_mode == ForwardMode.DECODE:
+        paged_kernel_lens = seq_lens
+    else:
+        paged_kernel_lens = prefix_lens
+
+    kv_indptr = torch.zeros(
+        (batch_size + 1,), dtype=torch.int32, device="cuda"
+    )
+    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+    kv_indices = torch.cat(
+        [
+            model_runner.req_to_token_pool.req_to_token[
+                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
+            ]
+            for i in range(batch_size)
+        ],
+        dim=0,
+    ).contiguous()
+    kv_last_page_len = torch.ones(
+        (batch_size,), dtype=torch.int32, device="cuda"
+    )
+
+    if forward_mode == ForwardMode.DECODE:
+        flashinfer_decode_wrapper.end_forward()
+        flashinfer_decode_wrapper.begin_forward(
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
+    else:
+        # extend part
+        qo_indptr = torch.zeros(
+            (batch_size + 1,), dtype=torch.int32, device="cuda"
+        )
+        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+        model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+        model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+            qo_indptr,
+            qo_indptr,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+        )
+
+        # cached part
+        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
+        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
+
+
+def init_triton_args(forward_mode, seq_lens, prefix_lens):
+    batch_size = len(seq_lens)
+    max_seq_len = int(torch.max(seq_lens))
+    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
+    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
+
+    if forward_mode == ForwardMode.DECODE:
+        max_extend_len = None
+    else:
+        extend_seq_lens = seq_lens - prefix_lens
+        max_extend_len = int(torch.max(extend_seq_lens))
+
+    return max_seq_len, max_extend_len, start_loc, prefix_lens
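
The new helpers at the end of the diff (InputMetadata.create, init_flashinfer_args, init_triton_args) all build exclusive prefix sums (kv_indptr, qo_indptr, start_loc, extend_start_loc) so that each request maps to a contiguous slice of one flattened index array. A minimal CPU-only sketch of that indexing layout (illustrative only, not the sglang API; the lengths are made up):

    import torch

    # Hypothetical per-request KV lengths for a batch of three requests.
    paged_kernel_lens = torch.tensor([3, 5, 2], dtype=torch.int32)

    # Exclusive prefix sum, as done for kv_indptr above:
    # indptr[i] is where request i's entries start in the flattened kv_indices array.
    indptr = torch.zeros((len(paged_kernel_lens) + 1,), dtype=torch.int32)
    indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
    print(indptr.tolist())  # [0, 3, 8, 10]

    # Request i owns the slice kv_indices[indptr[i]:indptr[i + 1]].
    for i in range(len(paged_kernel_lens)):
        print(f"request {i}: [{int(indptr[i])}, {int(indptr[i + 1])})")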