sglang 0.5.1__py3-none-any.whl → 0.5.1.post2__py3-none-any.whl

This diff compares two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
@@ -334,6 +334,8 @@ class DecodePreallocQueue:
                     error_message,
                     status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
                 )
+                if self.scheduler.enable_metrics:
+                    self.scheduler.metrics_collector.increment_bootstrap_failed_reqs()
             else:
                 raise ValueError(f"Unexpected poll case: {poll}")
 
@@ -595,6 +597,8 @@ class DecodeTransferQueue:
                 # unlock the kv cache or it will have memory leak
                 self.tree_cache.cache_finished_req(decode_req.req)
                 indices_to_remove.add(i)
+                if self.scheduler.enable_metrics:
+                    self.scheduler.metrics_collector.increment_transfer_failed_reqs()
                 continue
             elif poll == KVPoll.Success:
 
@@ -238,6 +238,8 @@ class PrefillBootstrapQueue:
                 self.scheduler.stream_output([req], req.return_logprob)
                 indices_to_remove.add(i)
                 failed_reqs.append(req)
+                if self.scheduler.enable_metrics:
+                    self.scheduler.metrics_collector.increment_bootstrap_failed_reqs()
                 continue
 
             # KV.WaitingForInput - init here
@@ -522,6 +524,8 @@ class SchedulerDisaggregationPrefillMixin:
                     req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
                 )
                 done_reqs.append(req)
+                if self.enable_metrics:
+                    self.metrics_collector.increment_transfer_failed_reqs()
             else:
                 assert False, f"Unexpected polling state {poll=}"
 
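The four hunks above route KV bootstrap and KV transfer failures in the disaggregation queues into the scheduler's metrics collector, guarded by enable_metrics. The collector methods themselves are not part of this diff; the sketch below only illustrates the general shape of such counters with prometheus_client, using invented metric names rather than sglang's actual SchedulerMetricsCollector implementation:

    # Hedged sketch: metric names, labels, and class shape are assumptions.
    from prometheus_client import Counter

    class DisaggregationFailureCounters:
        def __init__(self, labels: dict):
            # labels is assumed non-empty, e.g. {"model_name": "...", "engine_type": "prefill"}
            self.labels = labels
            self.bootstrap_failed_reqs = Counter(
                "sglang_bootstrap_failed_reqs_total",  # hypothetical metric name
                "Requests that failed while bootstrapping the KV connection",
                labelnames=list(labels.keys()),
            )
            self.transfer_failed_reqs = Counter(
                "sglang_transfer_failed_reqs_total",  # hypothetical metric name
                "Requests whose KV transfer failed",
                labelnames=list(labels.keys()),
            )

        def increment_bootstrap_failed_reqs(self) -> None:
            self.bootstrap_failed_reqs.labels(**self.labels).inc()

        def increment_transfer_failed_reqs(self) -> None:
            self.transfer_failed_reqs.labels(**self.labels).inc()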
@@ -672,7 +672,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.11.post3",
+            "0.2.14.post1",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
@@ -4,6 +4,8 @@ import os
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any
 
+from sglang.srt.utils import print_info_once, print_warning_once
+
 if TYPE_CHECKING:
     # Avoid circular import.
     from sglang.srt.entrypoints.context import ConversationContext
@@ -25,7 +27,7 @@ class HarmonyBrowserTool(Tool):
         exa_api_key = os.getenv("EXA_API_KEY")
         if not exa_api_key:
             self.enabled = False
-            logger.warning_once("EXA_API_KEY is not set, browsing is disabled")
+            print_warning_once("EXA_API_KEY is not set, browsing is disabled")
             return
 
         try:
@@ -33,12 +35,12 @@ class HarmonyBrowserTool(Tool):
             from gpt_oss.tools.simple_browser.backend import ExaBackend
         except ImportError:
             self.enabled = False
-            logger.warning_once("gpt_oss is not installed, browsing is disabled")
+            print_warning_once("gpt_oss is not installed, browsing is disabled")
             return
 
         browser_backend = ExaBackend(source="web", api_key=exa_api_key)
         self.browser_tool = SimpleBrowserTool(backend=browser_backend)
-        logger.info_once("Browser tool initialized")
+        print_info_once("Browser tool initialized")
 
     async def get_result(self, context: "ConversationContext") -> Any:
         from sglang.srt.entrypoints.context import HarmonyContext
@@ -64,13 +66,11 @@ class HarmonyPythonTool(Tool):
             from gpt_oss.tools.python_docker.docker_tool import PythonTool
         except ImportError:
             self.enabled = False
-            logger.warning_once(
-                "gpt_oss is not installed, code interpreter is disabled"
-            )
+            print_warning_once("gpt_oss is not installed, code interpreter is disabled")
             return
 
         self.python_tool = PythonTool()
-        logger.info_once("Code interpreter tool initialized")
+        print_info_once("Code interpreter tool initialized")
 
     async def get_result(self, context: "ConversationContext") -> Any:
         from sglang.srt.entrypoints.context import HarmonyContext
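The Harmony tool hunks above replace logger.warning_once / logger.info_once calls with print_warning_once / print_info_once imported from sglang.srt.utils; the standard logging.Logger has no *_once methods. A minimal log-once helper can be built with functools.lru_cache, shown here only as an illustration of the pattern, not as the actual sglang implementation:

    import functools
    import logging

    logger = logging.getLogger(__name__)

    @functools.lru_cache(maxsize=None)
    def print_warning_once(msg: str) -> None:
        # lru_cache keys on the message, so each distinct message is logged once.
        logger.warning(msg)

    @functools.lru_cache(maxsize=None)
    def print_info_once(msg: str) -> None:
        logger.info(msg)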
@@ -24,7 +24,9 @@ if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
 
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
+from sglang.srt.layers.attention.flashinfer_backend import (
+    create_flashinfer_kv_indices_triton,
+)
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -179,6 +181,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_indptr_decode_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()
+
         # Parse constants
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
@@ -210,25 +213,15 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         else:
             self.kv_indptr = kv_indptr_buf
 
-        self.kv_indices = torch.empty(
-            (max_bs * (self.max_context_len + self.page_size - 1) // self.page_size,),
-            dtype=torch.int32,
-            device=model_runner.device,
-        )
-
         if not self.skip_prefill:
             self.qo_indptr = torch.zeros(
                 (max_bs + 1,), dtype=torch.int32, device=model_runner.device
             )
 
         if q_indptr_decode_buf is None:
-            # A hack to pre-initialize large batch size for dp attention
-            if model_runner.server_args.enable_dp_attention:
-                max_bs = model_runner.server_args.dp_size * max_bs
             self.q_indptr_decode = torch.arange(
                 0, max_bs + 1, dtype=torch.int32, device=model_runner.device
             )
-
         else:
             self.q_indptr_decode = q_indptr_decode_buf
 
@@ -273,7 +266,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         self.prefill_cuda_graph_metadata = {}  # For verify
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
-
         if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
                 forward_batch.req_pool_indices,
@@ -331,9 +323,16 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         max_num_tokens: int,
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
-        self.cuda_graph_kv_indices = (
-            self.kv_indices.clone() if kv_indices_buf is None else kv_indices_buf
-        )
+        if kv_indices_buf is None:
+            cuda_graph_kv_indices = torch.zeros(
+                (max_bs * self.max_context_len,),
+                dtype=torch.int32,
+                device="cuda",
+            )
+        else:
+            cuda_graph_kv_indices = kv_indices_buf
+
+        self.cuda_graph_kv_indices = cuda_graph_kv_indices
         self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
         self.cuda_graph_kv_indptr = self.kv_indptr.clone()
         self.cuda_graph_kv_lens = torch.ones(
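With the paged bookkeeping gone, the CUDA-graph KV indices buffer in the hunk above is sized per token rather than per page: max_bs * self.max_context_len int32 entries. Rough memory arithmetic for illustration (the batch size and context length below are assumed values, not sglang defaults):

    max_bs = 160             # assumed CUDA-graph max batch size
    max_context_len = 8192   # assumed model context length

    num_entries = max_bs * max_context_len   # 1,310,720 int32 slots
    size_bytes = num_entries * 4             # int32 is 4 bytes
    print(f"{size_bytes / 2**20:.1f} MiB")   # ~5.0 MiB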
@@ -359,7 +358,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
     ):
-
         if forward_mode.is_decode_or_idle():
             decode_wrapper = BatchMLAPagedAttentionWrapper(
                 self.workspace_buffer,
@@ -370,6 +368,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                 kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
                 backend="auto",
             )
+
             seq_lens_sum = seq_lens.sum().item()
             self.indices_updater_decode.update(
                 req_pool_indices,
@@ -440,13 +439,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInfo],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
-
         if forward_mode.is_decode_or_idle():
             assert seq_lens_cpu is not None
             kv_len_arr_cpu = seq_lens_cpu[:bs]
-            num_pages_per_req = (seq_lens_cpu + self.page_size - 1) // self.page_size
             self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
-                num_pages_per_req, dim=0
+                kv_len_arr_cpu, dim=0
             )
             self.fast_decode_kwargs.update(
                 {
@@ -455,6 +452,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                     "kv_len_arr_cpu": kv_len_arr_cpu,
                 }
             )
+
             self.indices_updater_decode.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
@@ -534,6 +532,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             q_rope = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
+
         if self.forward_metadata.use_ragged:
             # ragged prefill
             if q_rope is not None:
@@ -554,8 +553,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
                 q.dtype
             )
-            k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
-
             if q_rope is None:
                 qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
                 q, q_rope = (
@@ -617,17 +614,17 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             q_nope = reshaped_q[:, :, : layer.v_head_dim]
             q_rope = reshaped_q[:, :, layer.v_head_dim :]
 
-            k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
+            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
                 q.dtype
             )
-            k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
 
             o = q_nope.new_empty(q_nope.shape)
+            # Direct call to run without the wrapper
             o = decode_wrapper.run(
                 q_nope,
                 q_rope,
-                k_buf[:, :, : layer.v_head_dim],
-                k_buf[:, :, layer.v_head_dim :],
+                k_buffer[:, :, : layer.v_head_dim],
+                k_buffer[:, :, layer.v_head_dim :],
                 out=o,
             )
 
@@ -646,10 +643,9 @@ class FlashInferMLAIndicesUpdaterDecode:
         self.scaling = model_runner.model_config.scaling
         self.data_type = model_runner.dtype
         self.attn_backend = attn_backend
-        self.page_size = model_runner.page_size
+
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
-        self.kv_indices = attn_backend.kv_indices
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.q_indptr = attn_backend.q_indptr_decode
 
@@ -693,17 +689,13 @@ class FlashInferMLAIndicesUpdaterDecode:
         kv_lens = paged_kernel_lens.to(torch.int32)
         sm_scale = self.scaling
         if spec_info is None:
-            num_pages_per_req = (
-                paged_kernel_lens + self.page_size - 1
-            ) // self.page_size
-            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = (
-                self.kv_indices[: kv_indptr[-1]]
+                torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
                 if not init_metadata_replay
                 else fast_decode_kwargs["kv_indices"]
             )
-
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
@@ -712,40 +704,39 @@ class FlashInferMLAIndicesUpdaterDecode:
                 None,
                 kv_indices,
                 self.req_to_token.shape[1],
-                self.page_size,
             )
         else:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
 
         if not init_metadata_replay:
             wrapper.plan(
-                qo_indptr=q_indptr,
-                kv_indptr=kv_indptr,
-                kv_indices=kv_indices,
-                kv_len_arr=kv_lens,
-                num_heads=self.num_local_heads,
-                head_dim_ckv=self.kv_lora_rank,
-                head_dim_kpe=self.qk_rope_head_dim,
-                page_size=self.page_size,
-                causal=False,
-                sm_scale=sm_scale,
-                q_data_type=self.data_type,
-                kv_data_type=self.data_type,
+                q_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_lens,
+                self.num_local_heads,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+                1,
+                False,
+                sm_scale,
+                self.data_type,
+                self.data_type,
             )
         else:
             wrapper.plan(
-                qo_indptr_cpu=fast_decode_kwargs["qo_indptr_cpu"],
-                kv_indptr_cpu=fast_decode_kwargs["kv_indptr_cpu"],
-                kv_indices=kv_indices,
-                kv_len_arr_cpu=fast_decode_kwargs["kv_len_arr_cpu"],
-                num_heads=self.num_local_heads,
-                head_dim_ckv=self.kv_lora_rank,
-                head_dim_kpe=self.qk_rope_head_dim,
-                page_size=self.page_size,
-                causal=False,
-                sm_scale=sm_scale,
-                q_data_type=self.data_type,
-                kv_data_type=self.data_type,
+                fast_decode_kwargs["qo_indptr_cpu"],
+                fast_decode_kwargs["kv_indptr_cpu"],
+                kv_indices,
+                fast_decode_kwargs["kv_len_arr_cpu"],
+                self.num_local_heads,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+                1,
+                False,
+                sm_scale,
+                self.data_type,
+                self.data_type,
            )
 
 
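Both wrapper.plan(...) calls in the hunk directly above (and the prefill-side call further down) switch from keyword to positional arguments, with the page size pinned to 1. Reading the removed keyword form against the added positional form in that hunk, the order corresponds to:

    # Positional order implied by the keyword call removed in the hunk above.
    wrapper.plan(
        q_indptr,               # qo_indptr
        kv_indptr,              # kv_indptr
        kv_indices,             # kv_indices
        kv_lens,                # kv_len_arr
        self.num_local_heads,   # num_heads
        self.kv_lora_rank,      # head_dim_ckv
        self.qk_rope_head_dim,  # head_dim_kpe
        1,                      # page_size, fixed to 1
        False,                  # causal
        sm_scale,               # sm_scale
        self.data_type,         # q_data_type
        self.data_type,         # kv_data_type
    )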
@@ -767,14 +758,12 @@ class FlashInferMLAIndicesUpdaterPrefill:
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
         self.qo_indptr = attn_backend.qo_indptr
-        self.kv_indices = attn_backend.kv_indices
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
-        self.page_size = model_runner.page_size
 
     def update(
         self,
-        req_pool_indices: torch.Tensor,
+        req_pool_indices: torch.Tnesor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
@@ -788,6 +777,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
         else:
             paged_kernel_lens = seq_lens
             paged_kernel_lens_sum = seq_lens_sum
+
         self.call_begin_forward(
             self.prefill_wrapper_ragged,
             prefill_wrapper_paged,
@@ -821,12 +811,13 @@ class FlashInferMLAIndicesUpdaterPrefill:
 
         if spec_info is None:
             assert len(seq_lens) == len(req_pool_indices)
-            num_pages_per_req = (
-                paged_kernel_lens + self.page_size - 1
-            ) // self.page_size
-            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = self.kv_indices[: kv_indptr[-1]]
+            kv_indices = torch.empty(
+                paged_kernel_lens_sum,
+                dtype=torch.int32,
+                device=req_pool_indices.device,
+            )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
@@ -835,7 +826,6 @@ class FlashInferMLAIndicesUpdaterPrefill:
                 None,
                 kv_indices,
                 self.req_to_token.shape[1],
-                self.page_size,
             )
             qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
@@ -853,6 +843,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
                     self.req_to_token,
                 )
             )
+
         if use_ragged:
             # ragged prefill
             wrapper_ragged.begin_forward(
@@ -867,26 +858,20 @@ class FlashInferMLAIndicesUpdaterPrefill:
             )
         else:
             # mla paged prefill
-            if spec_info is not None:
-                assert (
-                    self.page_size == 1
-                ), "Only page_size=1 is supported for flashinfer backend with speculative decoding"
-                kv_lens = kv_indptr[1:] - kv_indptr[:-1]
-            else:
-                kv_lens = paged_kernel_lens.to(torch.int32)
+            kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
             wrapper_paged.plan(
-                qo_indptr=qo_indptr,
-                kv_indptr=kv_indptr,
-                kv_indices=kv_indices,
-                kv_len_arr=kv_lens,
-                num_heads=self.num_local_heads,
-                head_dim_ckv=self.kv_lora_rank,
-                head_dim_kpe=self.qk_rope_head_dim,
-                page_size=self.page_size,
-                causal=True,
-                sm_scale=sm_scale,
-                q_data_type=self.q_data_type,
-                kv_data_type=self.data_type,
+                qo_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_len_arr,
+                self.num_local_heads,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+                1,
+                True,
+                sm_scale,
+                self.q_data_type,
+                self.data_type,
             )
 
 
@@ -981,7 +966,6 @@ class FlashInferMLAMultiStepDraftBackend:
             call_fn(i, forward_batch)
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
-
         kv_indices = torch.zeros(
             (
                 self.speculative_num_steps,
@@ -1017,7 +1001,6 @@ class FlashInferMLAMultiStepDraftBackend:
         )
 
     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
-
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                 forward_batch.batch_size,
@@ -1034,7 +1017,6 @@ class FlashInferMLAMultiStepDraftBackend:
     def init_forward_metadata_replay_cuda_graph(
         self, forward_batch: ForwardBatch, bs: int
     ):
-
        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                bs,
@@ -9,89 +9,18 @@ TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
 
 @triton.jit
 def create_flashinfer_kv_indices_triton(
-    req_to_token_ptr,
+    req_to_token_ptr,  # [max_batch, max_context_len]
     req_pool_indices_ptr,
     page_kernel_lens_ptr,
     kv_indptr,
     kv_start_idx,
     kv_indices_ptr,
     req_to_token_ptr_stride: tl.constexpr,
-    PAGE_SIZE: tl.constexpr = 1,
 ):
-    """
-    Create KV indices for FlashInfer attention backend.
-
-    This Triton kernel builds a lookup table that maps from logical request/token
-    coordinates to physical token locations in the global KV cache pool. It's used
-    by FlashInfer attention backends to efficiently access scattered KV cache data.
-
-    The kernel processes each request in parallel and converts the req_to_token
-    lookup table into a flat list of token indices that can be used by attention kernels.
-
-    general idea:
-    blocktables/kv_indices_ptr = [batch_size * max_pages (for graph mode with
-    fixed number of pages)]
-    max_pages = max_context_len / PAGED_SIZE
-    kv_indices_ptr will store the flat list of the pages used by each request
-    Args:
-    Inputs Arguments (non mutable):
-
-    req_to_token_ptr: Request to token location look up table
-        Shape: [max_batch, max_context_len]
-    req_pool_indices_ptr: Request to pool index look up table. Each request uses
-        one pool.
-        Shape: [batch_size]
-    page_kernel_lens_ptr: sequence lengths per request
-        Shape: [batch_size]
-    kv_indptr: Should be computed based on number of pages used by each request.
-        It is used by flashinfer attention kernels to index into the kv_indices_ptr
-        per request.
-        Shape: [batch_size + 1]
-        kv_indptr[i] = start index in kv_indices for request i
-    kv_start_idx: Pointer to array containing start offsets for each request in SGL.
-        Can be None. If provided, adds offset to token positions.
-
-    req_to_token_ptr_stride: Stride for the second dimension of req_to_token.
-        Equal to max_context_len.
-
-    PAGED_SIZE: Number of tokens per page. Default is 1 for FlashInfer.
-
-    Outputs:
-    kv_indices_ptr: Pointer to output array where KV indices will be stored.
-        Shape: [total_num_pages],
-        where total_num_pages = sum(seq_lens // PAGED_SIZE)
-
-    Example:
-    If we have:
-    - req_pool_indices = [0, 1] (request 0 uses pool 0, request 1 uses pool 1)
-    - page_kernel_lens = [3, 2] (request 0 has 3 tokens, request 1 has 2 tokens)
-    - req_to_token = [[10, 11, 12, -1], [20, 21, -1, -1]] (tokens are the elements
-      in radix tree, use them as a pointer to the token location in the kv_indices_ptr)
-
-    The kernel will output:
-    If PAGE_SIZE = 1:
-    packed
-    - kv_indptr (passed in as input arg): [0, 3, 5]
-    - kv_indices = [10, 11, 12, 20, 21]
-    padded - max_pages is 10 tokens per req
-    - kv_indptr (passed in as input arg): [0, 10, 20]
-    - kv_indices = [10, 11, 12, -1, -1, -1, -1, -1, -1, -1,
-                    20, 21, -1, -1, -1, -1, -1, -1, -1, -1]
-
-    If PAGE_SIZE = 2
-    packed:
-    - kv_indptr (passed in as input arg): [0, 3, 4]
-    - kv_indices = [5, 6, 10]
-    padded: max_pages is 4
-    - kv_indptr (passed in as input arg): [0, 4, 8, ..] (note that 4 is the max_pages)
-    - kv_indices = [5, 6, -1, -1,
-                    10, -1, -1, -1]
-    This allows attention kernels to directly access the correct KV cache
-    entries for each request's tokens.
-    """
     BLOCK_SIZE: tl.constexpr = 512
-    NUM_PAGES_PER_BLOCK: tl.constexpr = BLOCK_SIZE // PAGE_SIZE
     pid = tl.program_id(axis=0)
+
+    # find the req pool idx, this is for batch to token
     req_pool_index = tl.load(req_pool_indices_ptr + pid)
     kv_indices_offset = tl.load(kv_indptr + pid)
 
@@ -102,27 +31,19 @@ def create_flashinfer_kv_indices_triton(
     kv_end = kv_start
     kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
 
-    kv_range = kv_end - kv_start
-    num_pages = tl.cdiv(kv_range, PAGE_SIZE)
-    num_loops = tl.cdiv(kv_range, BLOCK_SIZE)
-    req_to_token_block_start = (
-        req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + kv_start
-    )
-    for i in range(num_loops):
-        token_offsets_in_block = (
-            tl.arange(0, NUM_PAGES_PER_BLOCK).to(tl.int64) + i * NUM_PAGES_PER_BLOCK
-        ) * PAGE_SIZE
-        page_offsets_in_block = token_offsets_in_block // PAGE_SIZE
-        valid_tokens = token_offsets_in_block < kv_range
-        valid_pages = page_offsets_in_block < num_pages
-        token_numbers = tl.load(
-            req_to_token_block_start + token_offsets_in_block, mask=valid_tokens
-        )
-        tl.store(
-            kv_indices_ptr + kv_indices_offset + page_offsets_in_block,
-            token_numbers // PAGE_SIZE,  # write the page numbers to kv_indices_ptr
-            mask=valid_pages,
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for i in range(num_loop):
+        # index into req_to_token_ptr needs to be int64
+        offset = tl.arange(0, BLOCK_SIZE).to(tl.int64) + i * BLOCK_SIZE
+        mask = offset < kv_end - kv_start
+        data = tl.load(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + kv_start
+            + offset,
+            mask=mask,
         )
+        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
 
 
 @triton.jit
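The rewritten kernel copies raw token indices one block at a time instead of converting them to page numbers, i.e. it assumes a page size of 1. Restated in plain NumPy, using the example from the docstring removed earlier in this diff and ignoring the optional kv_start_idx offset (a readability aid, not code from the package):

    import numpy as np

    # Example values taken from the removed docstring.
    req_to_token = np.array([[10, 11, 12, -1], [20, 21, -1, -1]])  # [max_batch, max_context_len]
    req_pool_indices = np.array([0, 1])
    page_kernel_lens = np.array([3, 2])
    kv_indptr = np.concatenate(([0], np.cumsum(page_kernel_lens)))  # [0, 3, 5]

    kv_indices = np.empty(kv_indptr[-1], dtype=np.int32)
    for pid, (pool, length) in enumerate(zip(req_pool_indices, page_kernel_lens)):
        # Copy each request's first `length` token slots into the flat output.
        kv_indices[kv_indptr[pid] : kv_indptr[pid + 1]] = req_to_token[pool, :length]

    print(kv_indices.tolist())  # [10, 11, 12, 20, 21]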
@@ -157,10 +157,6 @@ def cutlass_fused_experts_fp8(
     rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
     rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
 
-    if not is_sm100_supported():
-        rep_a1_scales = per_group_transpose(rep_a1_scales, expert_offsets)
-        w1_scale = w1_scale.contiguous()
-
     c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
     c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
 
@@ -192,9 +188,6 @@ def cutlass_fused_experts_fp8(
     silu_and_mul(c1, intermediate)
 
     intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
-    if not is_sm100_supported():
-        a2_scale = per_group_transpose(a2_scale, expert_offsets)
-        w2_scale = w2_scale.contiguous()
 
     fp8_blockwise_scaled_grouped_mm(
         c2,