sglang 0.4.3.post3__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sglang/bench_serving.py CHANGED
@@ -220,7 +220,7 @@ async def async_request_openai_completions(
 
                     most_recent_timestamp = timestamp
                     generated_text += data["choices"][0]["text"]
-                    output_len = data.get("usage", {}).get(
+                    output_len = (data.get("usage") or {}).get(
                         "completion_tokens", output_len
                     )
 
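
The bench_serving.py change above guards against streaming chunks that carry "usage": null rather than omitting the key, which OpenAI-compatible servers may send for intermediate chunks. A minimal sketch of the difference; the sample dicts are illustrative, not real server output:

# dict.get's default only applies when the key is absent, not when the value
# is explicitly None, so the old form could raise on a null "usage" field.
chunk_without_usage = {"choices": [{"text": "hi"}]}
chunk_with_null_usage = {"choices": [{"text": "hi"}], "usage": None}

print(chunk_without_usage.get("usage", {}).get("completion_tokens", 0))        # 0
print((chunk_with_null_usage.get("usage") or {}).get("completion_tokens", 0))  # 0
# Old form on the null case:
# chunk_with_null_usage.get("usage", {}).get("completion_tokens", 0)
# -> AttributeError: 'NoneType' object has no attribute 'get'
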
@@ -7,16 +7,14 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """
 
-import math
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Callable, List, Optional, Union
 
 import torch
 import triton
-import triton.language as tl
 
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -37,7 +35,7 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.decode import PosEncodingMode
+    from flashinfer.decode import _get_range_buf, get_seq_lens
 
 
 class WrapperDispatch(Enum):
@@ -73,8 +71,6 @@ class FlashInferAttnBackend(AttentionBackend):
     ):
         super().__init__()
 
-        self.is_multimodal = model_runner.model_config.is_multimodal
-
         # Parse constants
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -86,6 +82,7 @@ class FlashInferAttnBackend(AttentionBackend):
         )
         self.max_context_len = model_runner.model_config.context_len
         self.skip_prefill = skip_prefill
+        self.is_multimodal = model_runner.model_config.is_multimodal
 
         assert not (
             model_runner.sliding_window_size is not None
@@ -115,7 +112,6 @@ class FlashInferAttnBackend(AttentionBackend):
                 device=model_runner.device,
             )
         self.workspace_buffer = global_workspace_buffer
-
         max_bs = model_runner.req_to_token_pool.size
         if kv_indptr_buf is None:
             self.kv_indptr = [
@@ -163,9 +159,11 @@ class FlashInferAttnBackend(AttentionBackend):
                     )
                 )
                 self.prefill_wrappers_verify.append(
-                    BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                    )
                 )
-
             self.decode_wrappers.append(
                 BatchDecodeWithPagedKVCacheWrapper(
                     self.workspace_buffer,
@@ -178,13 +176,14 @@ class FlashInferAttnBackend(AttentionBackend):
         if not skip_prefill:
             self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
                 model_runner, self
-            )
+            )  # for verify
         self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
 
         # Other metadata
         self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
         self.decode_cuda_graph_metadata = {}
-        self.prefill_cuda_graph_metadata = {}
+        self.prefill_cuda_graph_metadata = {}  # For verify
+        self.draft_extend_cuda_graph_metadata = {}  # For draft extend
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode_or_idle():
@@ -300,7 +299,6 @@ class FlashInferAttnBackend(AttentionBackend):
                     ],
                 )
             )
-
         seq_lens_sum = seq_lens.sum().item()
         self.indices_updater_decode.update(
             req_pool_indices,
@@ -312,6 +310,10 @@ class FlashInferAttnBackend(AttentionBackend):
             )
             self.decode_cuda_graph_metadata[bs] = decode_wrappers
             self.forward_metadata = DecodeMetadata(decode_wrappers)
+            for i in range(self.num_wrappers):
+                decode_wrappers[i].begin_forward = partial(
+                    fast_decode_plan, decode_wrappers[i]
+                )
         elif forward_mode.is_target_verify():
             prefill_wrappers = []
             for i in range(self.num_wrappers):
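
The loop added above swaps each CUDA-graph decode wrapper's begin_forward for fast_decode_plan bound to that wrapper via functools.partial. A self-contained sketch of this instance-level rebinding pattern; the class and return strings are illustrative, not sglang's:

from functools import partial

class DecodeWrapper:
    def begin_forward(self, *args, **kwargs):
        return "default plan: re-derives buffers and copies indptr to host"

def fast_decode_plan(wrapper, *args, **kwargs):
    # `wrapper` is pre-bound by partial, so callers keep the old call signature.
    return "fast plan: reuses the cuda-graph buffers already attached to wrapper"

w = DecodeWrapper()
w.begin_forward = partial(fast_decode_plan, w)  # same trick as in the hunk above
print(w.begin_forward())  # -> "fast plan: ..."
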
@@ -437,7 +439,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                 causal=False,
                 sm_scale=layer.scaling,
-                logits_soft_cap=layer.logit_cap,
+                logits_soft_cap=logits_soft_cap,
             )
 
             o, _ = merge_state(o1, s1, o2, s2)
@@ -636,9 +638,15 @@ class FlashInferIndicesUpdaterDecode:
             bs = len(req_pool_indices)
             kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.empty(
-                paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
-            )
+
+            if wrapper.is_cuda_graph_enabled:
+                # Directly write to the cuda graph input buffer
+                kv_indices = wrapper._paged_kv_indices_buf
+            else:
+                kv_indices = torch.empty(
+                    paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+                )
+
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
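
When CUDA graphs are enabled, the branch added above writes kv_indices straight into the wrapper's preallocated indices buffer instead of allocating a fresh tensor. A hedged sketch of why that matters: a captured graph replays against fixed device addresses, so new data must land in the buffer that was captured. Buffer and function names below are illustrative, and device="cpu" is used only so the sketch runs anywhere:

import torch

# Illustrative stand-in for wrapper._paged_kv_indices_buf: a buffer whose
# address is baked into the captured graph (device="cuda" in the real code).
CAPACITY, DEVICE = 1024, "cpu"
paged_kv_indices_buf = torch.zeros(CAPACITY, dtype=torch.int32, device=DEVICE)

def get_kv_indices(num_indices: int, cuda_graph_enabled: bool) -> torch.Tensor:
    if cuda_graph_enabled:
        # Write in place: graph replay only sees data in the captured buffer.
        return paged_kv_indices_buf[:num_indices]
    # Eager path: a fresh allocation is fine because nothing was captured.
    return torch.empty(num_indices, dtype=torch.int32, device=DEVICE)

get_kv_indices(8, cuda_graph_enabled=True).fill_(1)  # lands in the shared buffer
print(paged_kv_indices_buf[:8])  # eight ones: the graph-visible buffer was updated in place
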
@@ -649,9 +657,9 @@ class FlashInferIndicesUpdaterDecode:
                 self.req_to_token.shape[1],
             )
         else:
-            assert isinstance(spec_info, EagleDraftInput)
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1
+
         wrapper.begin_forward(
             kv_indptr,
             kv_indices,
@@ -699,7 +707,7 @@ class FlashInferIndicesUpdaterPrefill:
 
     def update(
         self,
-        req_pool_indices: torch.Tnesor,
+        req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
@@ -713,7 +721,7 @@ class FlashInferIndicesUpdaterPrefill:
 
     def update_single_wrapper(
         self,
-        req_pool_indices: torch.Tnesor,
+        req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
@@ -858,7 +866,6 @@ class FlashInferIndicesUpdaterPrefill:
                 kv_indices,
                 self.req_to_token.shape[1],
             )
-
         qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
         qo_indptr = qo_indptr[: bs + 1]
         custom_mask = None
@@ -897,6 +904,7 @@ class FlashInferIndicesUpdaterPrefill:
                 self.head_dim,
                 1,
                 q_data_type=self.q_data_type,
+                kv_data_type=self.data_type,
                 custom_mask=custom_mask,
                 non_blocking=True,
             )
@@ -954,7 +962,10 @@ class FlashInferMultiStepDraftBackend:
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
 
     def common_template(
-        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+        self,
+        forward_batch: ForwardBatch,
+        kv_indices_buffer: torch.Tensor,
+        call_fn: Callable,
     ):
         num_seqs = forward_batch.batch_size
         bs = self.topk * num_seqs
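
The annotation fix above matches how call_fn is actually used: the hunks further down define it as def call_fn(i, forward_batch) and common_template invokes it per speculative step, so Callable (not int) is the accurate type. A simplified, self-contained sketch of that shape; the step loop and exact signature here are assumptions for illustration:

from typing import Any, Callable

def common_template_sketch(
    forward_batch: Any, call_fn: Callable[[int, Any], None], num_steps: int
) -> None:
    # Invoke the per-backend callback once per draft step.
    for i in range(num_steps):
        call_fn(i, forward_batch)

common_template_sketch({"batch_size": 2}, lambda i, fb: print("step", i, fb), 3)
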
@@ -1042,17 +1053,15 @@ class FlashInferMultiStepDraftBackend:
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
             )
-            decode_wrapper = self.attn_backends[i].decode_cuda_graph_metadata[
-                forward_batch.batch_size
-            ][0]
-            decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)
 
         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
 
-    def init_forward_metadata_replay_cuda_graph(self, forward_batch):
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
-                forward_batch.batch_size,
+                bs,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,
@@ -1113,6 +1122,11 @@ def should_use_tensor_core(
     return False
 
 
+# Use as a fast path to override the indptr in flashinfer's plan function
+# This is used to remove some host-to-device copy overhead.
+global_override_indptr_cpu = None
+
+
 def fast_decode_plan(
     self,
     indptr: torch.Tensor,
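
The module-level global added above lets a caller hand the planner a CPU copy of indptr it already holds, so fast_decode_plan (see the later hunk) can skip the device-to-host indptr.cpu() copy. A self-contained sketch of the publish-then-clear pattern; the helper name below is illustrative:

import torch

global_override_indptr_cpu = None  # set by a caller that already has a host copy

def resolve_indptr_host(indptr: torch.Tensor) -> torch.Tensor:
    # Mirrors the logic in fast_decode_plan: prefer the published host tensor,
    # otherwise fall back to a (synchronizing) copy off the device.
    return (
        global_override_indptr_cpu
        if global_override_indptr_cpu is not None
        else indptr.cpu()
    )

indptr = torch.tensor([0, 4, 9], dtype=torch.int32)
global_override_indptr_cpu = indptr.clone()  # publish before planning
assert resolve_indptr_host(indptr) is global_override_indptr_cpu
global_override_indptr_cpu = None            # clear afterwards
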
@@ -1142,6 +1156,9 @@ def fast_decode_plan(
     if logits_soft_cap is None:
         logits_soft_cap = 0.0
 
+    if self.use_tensor_cores:
+        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
+
     if self.is_cuda_graph_enabled:
         if batch_size != self._fixed_batch_size:
             raise ValueError(
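
The new tensor-core branch builds a host-side qo_indptr of size batch_size + 1 before planning. For decode, every request contributes exactly one query token, so that indptr is simply the range 0..batch_size; the snippet below is a generic stand-in for the cached buffer requested from _get_range_buf, not flashinfer's implementation:

import torch

# Decode-time qo_indptr: one query token per request, so offsets are 0..batch_size.
batch_size = 4
qo_indptr_host = torch.arange(batch_size + 1, dtype=torch.int32)
print(qo_indptr_host)  # tensor([0, 1, 2, 3, 4], dtype=torch.int32)
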
@@ -1154,7 +1171,7 @@ def fast_decode_plan(
             raise ValueError(
                 "The size of indices should be less than or equal to the allocated buffer"
             )
-        # Skip these copies
+        # Skip these copies because we directly write to them during prepartion
         # self._paged_kv_indptr_buf.copy_(indptr)
         # self._paged_kv_indices_buf[: len(indices)] = indices
         # self._paged_kv_last_page_len_buf.copy_(last_page_len)
@@ -1162,6 +1179,7 @@ def fast_decode_plan(
         self._paged_kv_indptr_buf = indptr
         self._paged_kv_indices_buf = indices
         self._paged_kv_last_page_len_buf = last_page_len
+        self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=non_blocking)
 
     # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
     if not q_data_type:
@@ -1184,27 +1202,55 @@ def fast_decode_plan(
         )
         self.last_page_len = torch.ones(32768, dtype=torch.int32)
 
-    empty_q_data = self.empty_q_data
-    empty_kv_cache = self.empty_kv_cache
-    stream = torch.cuda.current_stream()
-    self._cached_module.plan(
-        self._float_workspace_buffer,
-        self._int_workspace_buffer,
-        self._pin_memory_int_workspace_buffer,
-        indptr.to("cpu"),
-        batch_size,
-        num_qo_heads,
-        num_kv_heads,
-        page_size,
-        self.is_cuda_graph_enabled,
-        window_left,
-        logits_soft_cap,
-        head_dim,
-        head_dim,
-        empty_q_data,
-        empty_kv_cache,
-        stream.cuda_stream,
+    indptr_host = (
+        global_override_indptr_cpu
+        if global_override_indptr_cpu is not None
+        else indptr.cpu()
     )
+
+    if self.use_tensor_cores:
+        kv_lens_arr_host = get_seq_lens(
+            indptr_host, self.last_page_len[:batch_size], page_size
+        )
+
+        self._plan_info = self._cached_module.plan(
+            self._float_workspace_buffer,
+            self._int_workspace_buffer,
+            self._pin_memory_int_workspace_buffer,
+            qo_indptr_host,
+            indptr_host,
+            kv_lens_arr_host,
+            batch_size,  # total_num_rows
+            batch_size,
+            num_qo_heads,
+            num_kv_heads,
+            page_size,
+            self.is_cuda_graph_enabled,
+            head_dim,
+            head_dim,
+            False,  # causal
+            torch.cuda.current_stream().cuda_stream,
+        )
+    else:
+        self._plan_info = self._cached_module.plan(
+            self._float_workspace_buffer,
+            self._int_workspace_buffer,
+            self._pin_memory_int_workspace_buffer,
+            indptr_host,
+            batch_size,
+            num_qo_heads,
+            num_kv_heads,
+            page_size,
+            self.is_cuda_graph_enabled,
+            window_left,
+            logits_soft_cap,
+            head_dim,
+            head_dim,
+            self.empty_q_data,
+            self.empty_kv_cache,
+            torch.cuda.current_stream().cuda_stream,
+        )
+
     self._pos_encoding_mode = pos_encoding_mode
     self._window_left = window_left
     self._logits_soft_cap = logits_soft_cap
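
On the tensor-core path above, the plan call now takes per-request kv lengths computed on the host via get_seq_lens. Conceptually each length follows from the paged layout: all but the last page are full, and the last page holds last_page_len entries. The helper below is a generic reimplementation of that arithmetic for illustration only, not flashinfer's actual function:

import torch

def kv_lens_from_paged_layout(
    kv_indptr: torch.Tensor, last_page_len: torch.Tensor, page_size: int
) -> torch.Tensor:
    # Pages per request from the indptr, then (num_pages - 1) full pages
    # plus the partially filled last page.
    num_pages = kv_indptr[1:] - kv_indptr[:-1]
    return (num_pages - 1) * page_size + last_page_len

kv_indptr = torch.tensor([0, 3, 7], dtype=torch.int32)   # 3 and 4 pages
last_page_len = torch.tensor([5, 2], dtype=torch.int32)
print(kv_lens_from_paged_layout(kv_indptr, last_page_len, 16))  # tensor([37, 50], dtype=torch.int32)
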
@@ -578,10 +578,12 @@ class TritonMultiStepDraftBackend:
 
         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
 
-    def init_forward_metadata_replay_cuda_graph(self, forward_batch):
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
-                forward_batch.batch_size,
+                bs,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,
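
Both multi-step draft backends (FlashInfer above and Triton here) now receive the replay batch size as an explicit bs argument instead of reading forward_batch.batch_size. A hedged sketch of the likely reason, assuming the usual CUDA-graph convention that graphs are captured for a fixed set of batch sizes and a live batch is padded up to the nearest captured size (the capture list below is an assumption, not sglang's configuration):

import bisect

captured_batch_sizes = [1, 2, 4, 8, 16]  # assumed capture list

def padded_replay_bs(raw_bs: int) -> int:
    # Replay must use the padded (captured) size, which can differ from the
    # raw forward_batch.batch_size, hence the explicit `bs` parameter.
    return captured_batch_sizes[bisect.bisect_left(captured_batch_sizes, raw_bs)]

print(padded_replay_bs(3))  # 4
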
@@ -482,6 +482,7 @@ class BatchEmbeddingOut:
     embeddings: List[List[float]]
     # Token counts
     prompt_tokens: List[int]
+    cached_tokens: List[int]
 
 
 @dataclass
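
Because the new cached_tokens field has no default, every construction site of BatchEmbeddingOut must now pass it. A trimmed sketch limited to the fields visible in this hunk (the real dataclass likely has additional fields not shown here):

from dataclasses import dataclass
from typing import List

@dataclass
class BatchEmbeddingOut:
    embeddings: List[List[float]]
    # Token counts
    prompt_tokens: List[int]
    cached_tokens: List[int]  # new in 0.4.3.post4; required at construction

out = BatchEmbeddingOut(embeddings=[[0.1, 0.2]], prompt_tokens=[7], cached_tokens=[3])
print(out.cached_tokens)  # [3]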