sglang 0.5.1__py3-none-any.whl → 0.5.1.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/disaggregation/decode.py +4 -0
- sglang/srt/disaggregation/prefill.py +4 -0
- sglang/srt/entrypoints/engine.py +1 -1
- sglang/srt/entrypoints/tool.py +7 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +71 -89
- sglang/srt/layers/attention/utils.py +15 -94
- sglang/srt/layers/moe/cutlass_moe.py +0 -7
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +6 -2
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/lora/lora_manager.py +29 -12
- sglang/srt/managers/scheduler_metrics_mixin.py +15 -0
- sglang/srt/metrics/collector.py +5 -5
- sglang/srt/model_executor/cuda_graph_runner.py +2 -2
- sglang/srt/models/grok.py +0 -4
- sglang/srt/offloader.py +115 -0
- sglang/srt/server_args.py +0 -4
- sglang/srt/utils.py +0 -7
- sglang/test/test_cutlass_moe.py +33 -28
- sglang/version.py +1 -1
- {sglang-0.5.1.dist-info → sglang-0.5.1.post2.dist-info}/METADATA +4 -4
- {sglang-0.5.1.dist-info → sglang-0.5.1.post2.dist-info}/RECORD +25 -24
- {sglang-0.5.1.dist-info → sglang-0.5.1.post2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.dist-info → sglang-0.5.1.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.dist-info → sglang-0.5.1.post2.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/decode.py
CHANGED
@@ -334,6 +334,8 @@ class DecodePreallocQueue:
                     error_message,
                     status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
                 )
+                if self.scheduler.enable_metrics:
+                    self.scheduler.metrics_collector.increment_bootstrap_failed_reqs()
             else:
                 raise ValueError(f"Unexpected poll case: {poll}")

@@ -595,6 +597,8 @@ class DecodeTransferQueue:
                 # unlock the kv cache or it will have memory leak
                 self.tree_cache.cache_finished_req(decode_req.req)
                 indices_to_remove.add(i)
+                if self.scheduler.enable_metrics:
+                    self.scheduler.metrics_collector.increment_transfer_failed_reqs()
                 continue
             elif poll == KVPoll.Success:

sglang/srt/disaggregation/prefill.py
CHANGED
@@ -238,6 +238,8 @@ class PrefillBootstrapQueue:
                 self.scheduler.stream_output([req], req.return_logprob)
                 indices_to_remove.add(i)
                 failed_reqs.append(req)
+                if self.scheduler.enable_metrics:
+                    self.scheduler.metrics_collector.increment_bootstrap_failed_reqs()
                 continue

             # KV.WaitingForInput - init here
@@ -522,6 +524,8 @@ class SchedulerDisaggregationPrefillMixin:
                     req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
                 )
                 done_reqs.append(req)
+                if self.enable_metrics:
+                    self.metrics_collector.increment_transfer_failed_reqs()
             else:
                 assert False, f"Unexpected polling state {poll=}"

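Note on the four hunks above: both disaggregation queues now bump a failure counter when a request dies during KV bootstrap or KV transfer, guarded by the `enable_metrics` flag. As a rough illustration of what the two collector methods could look like (only the method names come from the diff; the metric names, labels, and the actual implementation in `sglang/srt/metrics/collector.py` are assumptions here), a minimal `prometheus_client` sketch:

```python
# Hedged sketch of the two failure counters referenced above; the real
# SchedulerMetricsCollector in sglang/srt/metrics/collector.py may use
# different metric names, labels, and registration.
from prometheus_client import Counter

BOOTSTRAP_FAILED = Counter(
    "sglang_disagg_bootstrap_failed_reqs_sketch",
    "Requests that failed during the prefill/decode KV bootstrap handshake.",
    labelnames=["model_name"],
)
TRANSFER_FAILED = Counter(
    "sglang_disagg_transfer_failed_reqs_sketch",
    "Requests that failed during KV cache transfer.",
    labelnames=["model_name"],
)


class SchedulerMetricsCollectorSketch:
    def __init__(self, model_name: str) -> None:
        self.model_name = model_name

    def increment_bootstrap_failed_reqs(self) -> None:
        BOOTSTRAP_FAILED.labels(model_name=self.model_name).inc()

    def increment_transfer_failed_reqs(self) -> None:
        TRANSFER_FAILED.labels(model_name=self.model_name).inc()
```

The `enable_metrics` guard in the hunks keeps the failure path free of any metric work when metrics collection is turned off.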
sglang/srt/entrypoints/engine.py
CHANGED
@@ -672,7 +672,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.
+            "0.2.14.post1",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
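The only change in this file is the minimum `flashinfer_python` version, now pinned to 0.2.14.post1. For readers unfamiliar with `assert_pkg_version`, a hedged sketch of this kind of check follows; it is not the actual helper in `sglang.srt.utils`, whose error handling may differ:

```python
# Sketch of a minimum-version assertion, assuming importlib.metadata and
# packaging are available. Not the real sglang.srt.utils.assert_pkg_version.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import parse


def assert_pkg_version_sketch(pkg: str, min_version: str, message: str) -> None:
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed. {message}")
    if parse(installed) < parse(min_version):
        raise RuntimeError(
            f"{pkg}=={installed} is older than the required {min_version}. {message}"
        )


# Mirrors the call shown in engine.py:
# assert_pkg_version_sketch("flashinfer_python", "0.2.14.post1", "Please reinstall ...")
```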
sglang/srt/entrypoints/tool.py
CHANGED
@@ -4,6 +4,8 @@ import os
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any

+from sglang.srt.utils import print_info_once, print_warning_once
+
 if TYPE_CHECKING:
     # Avoid circular import.
     from sglang.srt.entrypoints.context import ConversationContext
@@ -25,7 +27,7 @@ class HarmonyBrowserTool(Tool):
         exa_api_key = os.getenv("EXA_API_KEY")
         if not exa_api_key:
             self.enabled = False
-
+            print_warning_once("EXA_API_KEY is not set, browsing is disabled")
             return

         try:
@@ -33,12 +35,12 @@ class HarmonyBrowserTool(Tool):
             from gpt_oss.tools.simple_browser.backend import ExaBackend
         except ImportError:
             self.enabled = False
-
+            print_warning_once("gpt_oss is not installed, browsing is disabled")
             return

         browser_backend = ExaBackend(source="web", api_key=exa_api_key)
         self.browser_tool = SimpleBrowserTool(backend=browser_backend)
-
+        print_info_once("Browser tool initialized")

     async def get_result(self, context: "ConversationContext") -> Any:
         from sglang.srt.entrypoints.context import HarmonyContext
@@ -64,13 +66,11 @@ class HarmonyPythonTool(Tool):
             from gpt_oss.tools.python_docker.docker_tool import PythonTool
         except ImportError:
             self.enabled = False
-
-                "gpt_oss is not installed, code interpreter is disabled"
-            )
+            print_warning_once("gpt_oss is not installed, code interpreter is disabled")
             return

         self.python_tool = PythonTool()
-
+        print_info_once("Code interpreter tool initialized")

     async def get_result(self, context: "ConversationContext") -> Any:
         from sglang.srt.entrypoints.context import HarmonyContext
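The tool entrypoints now emit their status messages through `print_info_once` / `print_warning_once` from `sglang.srt.utils` instead of the previous multi-line calls. A minimal sketch of the "emit each distinct message only once" idea (the real helpers may be implemented differently, e.g. printing instead of using `logging`):

```python
# Sketch of once-only message helpers using functools.lru_cache.
# Not the actual sglang.srt.utils implementation.
import functools
import logging

logger = logging.getLogger(__name__)


@functools.lru_cache(maxsize=None)
def print_info_once_sketch(msg: str) -> None:
    # lru_cache ensures the body runs only once per distinct message string.
    logger.info(msg)


@functools.lru_cache(maxsize=None)
def print_warning_once_sketch(msg: str) -> None:
    logger.warning(msg)
```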
sglang/srt/layers/attention/flashinfer_mla_backend.py
CHANGED
@@ -24,7 +24,9 @@ if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":

 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.layers.attention.
+from sglang.srt.layers.attention.flashinfer_backend import (
+    create_flashinfer_kv_indices_triton,
+)
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -179,6 +181,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_indptr_decode_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()
+
         # Parse constants
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
@@ -210,25 +213,15 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         else:
             self.kv_indptr = kv_indptr_buf

-        self.kv_indices = torch.empty(
-            (max_bs * (self.max_context_len + self.page_size - 1) // self.page_size,),
-            dtype=torch.int32,
-            device=model_runner.device,
-        )
-
         if not self.skip_prefill:
             self.qo_indptr = torch.zeros(
                 (max_bs + 1,), dtype=torch.int32, device=model_runner.device
             )

         if q_indptr_decode_buf is None:
-            # A hack to pre-initialize large batch size for dp attention
-            if model_runner.server_args.enable_dp_attention:
-                max_bs = model_runner.server_args.dp_size * max_bs
             self.q_indptr_decode = torch.arange(
                 0, max_bs + 1, dtype=torch.int32, device=model_runner.device
             )
-
         else:
             self.q_indptr_decode = q_indptr_decode_buf

@@ -273,7 +266,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         self.prefill_cuda_graph_metadata = {}  # For verify

     def init_forward_metadata(self, forward_batch: ForwardBatch):
-
         if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
                 forward_batch.req_pool_indices,
@@ -331,9 +323,16 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         max_num_tokens: int,
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
-
-
-
+        if kv_indices_buf is None:
+            cuda_graph_kv_indices = torch.zeros(
+                (max_bs * self.max_context_len,),
+                dtype=torch.int32,
+                device="cuda",
+            )
+        else:
+            cuda_graph_kv_indices = kv_indices_buf
+
+        self.cuda_graph_kv_indices = cuda_graph_kv_indices
         self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
         self.cuda_graph_kv_indptr = self.kv_indptr.clone()
         self.cuda_graph_kv_lens = torch.ones(
@@ -359,7 +358,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
     ):
-
         if forward_mode.is_decode_or_idle():
             decode_wrapper = BatchMLAPagedAttentionWrapper(
                 self.workspace_buffer,
@@ -370,6 +368,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                 kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
                 backend="auto",
             )
+
             seq_lens_sum = seq_lens.sum().item()
             self.indices_updater_decode.update(
                 req_pool_indices,
@@ -440,13 +439,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInfo],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
-
         if forward_mode.is_decode_or_idle():
             assert seq_lens_cpu is not None
             kv_len_arr_cpu = seq_lens_cpu[:bs]
-            num_pages_per_req = (seq_lens_cpu + self.page_size - 1) // self.page_size
             self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
-
+                kv_len_arr_cpu, dim=0
             )
             self.fast_decode_kwargs.update(
                 {
@@ -455,6 +452,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                     "kv_len_arr_cpu": kv_len_arr_cpu,
                 }
             )
+
             self.indices_updater_decode.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
@@ -534,6 +532,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             q_rope = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
+
         if self.forward_metadata.use_ragged:
             # ragged prefill
             if q_rope is not None:
@@ -554,8 +553,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
                 q.dtype
             )
-            k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
-
             if q_rope is None:
                 qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
                 q, q_rope = (
@@ -617,17 +614,17 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             q_nope = reshaped_q[:, :, : layer.v_head_dim]
             q_rope = reshaped_q[:, :, layer.v_head_dim :]

-
+        k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
             q.dtype
         )
-        k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])

         o = q_nope.new_empty(q_nope.shape)
+        # Direct call to run without the wrapper
         o = decode_wrapper.run(
             q_nope,
             q_rope,
-
-
+            k_buffer[:, :, : layer.v_head_dim],
+            k_buffer[:, :, layer.v_head_dim :],
             out=o,
         )

@@ -646,10 +643,9 @@ class FlashInferMLAIndicesUpdaterDecode:
         self.scaling = model_runner.model_config.scaling
         self.data_type = model_runner.dtype
         self.attn_backend = attn_backend
-
+
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
-        self.kv_indices = attn_backend.kv_indices
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.q_indptr = attn_backend.q_indptr_decode

@@ -693,17 +689,13 @@ class FlashInferMLAIndicesUpdaterDecode:
         kv_lens = paged_kernel_lens.to(torch.int32)
         sm_scale = self.scaling
         if spec_info is None:
-
-                paged_kernel_lens + self.page_size - 1
-            ) // self.page_size
-            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = (
-
+                torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
                 if not init_metadata_replay
                 else fast_decode_kwargs["kv_indices"]
             )
-
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
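In the hunk above, the decode indices updater stops converting sequence lengths into page counts: `kv_indptr` is now a plain prefix sum of per-request token counts, and the MLA wrapper is planned with a page size of 1 (see the next hunk). A small, self-contained illustration of the difference, using made-up lengths and a hypothetical page size:

```python
# Illustration only: kv_indptr built from raw token counts (the new behavior in
# this backend) versus from per-request page counts (the removed page_size path).
import torch

paged_kernel_lens = torch.tensor([3, 5, 2], dtype=torch.int32)  # tokens per request
page_size = 4  # hypothetical page size used by the removed code path

kv_indptr_tokens = torch.zeros(len(paged_kernel_lens) + 1, dtype=torch.int32)
kv_indptr_tokens[1:] = torch.cumsum(paged_kernel_lens, dim=0)
# tensor([ 0,  3,  8, 10], dtype=torch.int32) -> one kv_indices slot per token

num_pages_per_req = (paged_kernel_lens + page_size - 1) // page_size
kv_indptr_pages = torch.zeros(len(paged_kernel_lens) + 1, dtype=torch.int32)
kv_indptr_pages[1:] = torch.cumsum(num_pages_per_req, dim=0)
# tensor([0, 1, 3, 4], dtype=torch.int32) -> one kv_indices slot per page
```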
@@ -712,40 +704,39 @@ class FlashInferMLAIndicesUpdaterDecode:
                 None,
                 kv_indices,
                 self.req_to_token.shape[1],
-                self.page_size,
             )
         else:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

         if not init_metadata_replay:
             wrapper.plan(
-
-                kv_indptr
-                kv_indices
-
-
-
-
-
-
-                sm_scale
-
-
+                q_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_lens,
+                self.num_local_heads,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+                1,
+                False,
+                sm_scale,
+                self.data_type,
+                self.data_type,
             )
         else:
             wrapper.plan(
-
-
-                kv_indices
-
-
-
-
-
-
-                sm_scale
-
-
+                fast_decode_kwargs["qo_indptr_cpu"],
+                fast_decode_kwargs["kv_indptr_cpu"],
+                kv_indices,
+                fast_decode_kwargs["kv_len_arr_cpu"],
+                self.num_local_heads,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+                1,
+                False,
+                sm_scale,
+                self.data_type,
+                self.data_type,
             )


@@ -767,14 +758,12 @@ class FlashInferMLAIndicesUpdaterPrefill:
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
         self.qo_indptr = attn_backend.qo_indptr
-        self.kv_indices = attn_backend.kv_indices
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
-        self.page_size = model_runner.page_size

     def update(
         self,
-        req_pool_indices: torch.
+        req_pool_indices: torch.Tnesor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
@@ -788,6 +777,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
         else:
             paged_kernel_lens = seq_lens
             paged_kernel_lens_sum = seq_lens_sum
+
         self.call_begin_forward(
             self.prefill_wrapper_ragged,
             prefill_wrapper_paged,
@@ -821,12 +811,13 @@ class FlashInferMLAIndicesUpdaterPrefill:

         if spec_info is None:
             assert len(seq_lens) == len(req_pool_indices)
-
-                paged_kernel_lens + self.page_size - 1
-            ) // self.page_size
-            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices =
+            kv_indices = torch.empty(
+                paged_kernel_lens_sum,
+                dtype=torch.int32,
+                device=req_pool_indices.device,
+            )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
@@ -835,7 +826,6 @@ class FlashInferMLAIndicesUpdaterPrefill:
                 None,
                 kv_indices,
                 self.req_to_token.shape[1],
-                self.page_size,
             )
             qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
@@ -853,6 +843,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
                     self.req_to_token,
                 )
             )
+
         if use_ragged:
             # ragged prefill
             wrapper_ragged.begin_forward(
@@ -867,26 +858,20 @@ class FlashInferMLAIndicesUpdaterPrefill:
             )
         else:
             # mla paged prefill
-
-                assert (
-                    self.page_size == 1
-                ), "Only page_size=1 is supported for flashinfer backend with speculative decoding"
-                kv_lens = kv_indptr[1:] - kv_indptr[:-1]
-            else:
-                kv_lens = paged_kernel_lens.to(torch.int32)
+            kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
             wrapper_paged.plan(
-                qo_indptr
-                kv_indptr
-                kv_indices
-                kv_len_arr
-
-
-
-
-
-                sm_scale
-
-
+                qo_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_len_arr,
+                self.num_local_heads,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+                1,
+                True,
+                sm_scale,
+                self.q_data_type,
+                self.data_type,
             )


@@ -981,7 +966,6 @@ class FlashInferMLAMultiStepDraftBackend:
             call_fn(i, forward_batch)

     def init_forward_metadata(self, forward_batch: ForwardBatch):
-
         kv_indices = torch.zeros(
             (
                 self.speculative_num_steps,
@@ -1017,7 +1001,6 @@ class FlashInferMLAMultiStepDraftBackend:
         )

     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
-
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                 forward_batch.batch_size,
@@ -1034,7 +1017,6 @@ class FlashInferMLAMultiStepDraftBackend:
     def init_forward_metadata_replay_cuda_graph(
         self, forward_batch: ForwardBatch, bs: int
     ):
-
        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                bs,
sglang/srt/layers/attention/utils.py
CHANGED
@@ -9,89 +9,18 @@ TRITON_PAD_NUM_PAGE_PER_BLOCK = 64

 @triton.jit
 def create_flashinfer_kv_indices_triton(
-    req_to_token_ptr,
+    req_to_token_ptr,  # [max_batch, max_context_len]
     req_pool_indices_ptr,
     page_kernel_lens_ptr,
     kv_indptr,
     kv_start_idx,
     kv_indices_ptr,
     req_to_token_ptr_stride: tl.constexpr,
-    PAGE_SIZE: tl.constexpr = 1,
 ):
-    """
-    Create KV indices for FlashInfer attention backend.
-
-    This Triton kernel builds a lookup table that maps from logical request/token
-    coordinates to physical token locations in the global KV cache pool. It's used
-    by FlashInfer attention backends to efficiently access scattered KV cache data.
-
-    The kernel processes each request in parallel and converts the req_to_token
-    lookup table into a flat list of token indices that can be used by attention kernels.
-
-    general idea:
-    blocktables/kv_indices_ptr = [batch_size * max_pages(for graph mode with
-    fixed number of pages)]
-    max_pages = max_context_len / PAGED_SIZE
-    kv_indices_ptr will store the flat list of the pages used by each request
-    Args:
-    Inputs Arguments (non mutable):
-
-    req_to_token_ptr: Request to token location look up table
-        Shape: [max_batch, max_context_len]
-    req_pool_indices_ptr: Request to pool index look up table. Each request uses
-        one pool.
-        Shape: [batch_size]
-    page_kernel_lens_ptr: sequence lengths per request
-        Shape: [batch_size]
-    kv_indptr: Should be computed based on number of pages used by each request.
-        It is used by flashinfer attention kernels to index into the kv_indices_ptr.
-        per request.
-        Shape: [batch_size + 1]
-        kv_indptr[i] = start index in kv_indices for request i
-    kv_start_idx: Pointer to array containing start offsets for each request in SGL.
-        Can be None. If provided, adds offset to token positions.
-
-    req_to_token_ptr_stride: Stride for the second dimension of req_to_token.
-        Equal to max_context_len.
-
-    PAGED_SIZE: Number of tokens per page. Default is 1 for FlashInfer.
-
-    Outputs:
-    kv_indices_ptr: Pointer to output array where KV indices will be stored.
-        Shape:[total-num-pages],
-        where total_num_pages = sum(seq_lens // PAGED_SIZE)
-
-    Example:
-    If we have:
-    - req_pool_indices = [0, 1] (request 0 uses pool 0, request 1 uses pool 1)
-    - page_kernel_lens = [3, 2] (request 0 has 3 tokens, request 1 has 2 tokens)
-    - req_to_token = [[10, 11, 12, -1], [20, 21, -1, -1]] (tokens are the elements
-      in radix tree, use them as a pointer to the token location in the kv_indices_ptr)
-
-    The kernel will output:
-    If PAGE_SIZE = 1:
-    packed
-    - kv_indptr (passed in as input arg): [0,3,5]
-    - kv_indices = [10, 11, 12, 20, 21]
-    padded - max_pages is 10 tokens per req
-    - kv_indptr (passed in as input arg): [0,10, 20]
-    - kv_indices = [10, 11, 12, -1, -1, -1, -1, -1, -1, -1,
-                    20, 21, -1, -1, -1, -1, -1, -1, -1, -1]
-
-    If PAGE_SIZE = 2
-    packed:
-    - kv_indptr (passed in as input arg): [0,3,4]
-    - kv_indices = [5,6,10]
-    padded: max_pages is 4
-    - kv_indptr (passed in as input arg): [0,4,8,..] (note that 4 is the max_pages)
-    - kv_indices = [5, 6, -1, -1,
-                    10, -1, -1, -1]
-    This allows attention kernels to directly access the correct KV cache
-    entries for each request's tokens.
-    """
     BLOCK_SIZE: tl.constexpr = 512
-    NUM_PAGES_PER_BLOCK: tl.constexpr = BLOCK_SIZE // PAGE_SIZE
     pid = tl.program_id(axis=0)
+
+    # find the req pool idx, this is for batch to token
     req_pool_index = tl.load(req_pool_indices_ptr + pid)
     kv_indices_offset = tl.load(kv_indptr + pid)

@@ -102,27 +31,19 @@ def create_flashinfer_kv_indices_triton(
     kv_end = kv_start
     kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)

-
-
-
-
-
-
-
-
-
-
-
-        valid_tokens = token_offsets_in_block < kv_range
-        valid_pages = page_offsets_in_block < num_pages
-        token_numbers = tl.load(
-            req_to_token_block_start + token_offsets_in_block, mask=valid_tokens
-        )
-        tl.store(
-            kv_indices_ptr + kv_indices_offset + page_offsets_in_block,
-            token_numbers // PAGE_SIZE,  # write the page numbers to kv_indices_ptr
-            mask=valid_pages,
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for i in range(num_loop):
+        # index into req_to_token_ptr needs to be int64
+        offset = tl.arange(0, BLOCK_SIZE).to(tl.int64) + i * BLOCK_SIZE
+        mask = offset < kv_end - kv_start
+        data = tl.load(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + kv_start
+            + offset,
+            mask=mask,
         )
+        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)


 @triton.jit
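With the PAGE_SIZE parameter and the long docstring removed, `create_flashinfer_kv_indices_triton` is back to a pure token-level gather: for each request it copies `page_kernel_lens[pid]` entries from that request's row of `req_to_token` into the flat `kv_indices` buffer at offset `kv_indptr[pid]`. A plain PyTorch reference of the same semantics (for illustration only; the Triton kernel above is what actually runs):

```python
# Reference semantics of the reverted kernel, written as a simple Python loop.
# Shapes follow the kernel arguments: req_to_token is [max_batch, max_context_len].
from typing import Optional

import torch


def create_flashinfer_kv_indices_reference(
    req_to_token: torch.Tensor,        # [max_batch, max_context_len]
    req_pool_indices: torch.Tensor,    # [batch_size]
    page_kernel_lens: torch.Tensor,    # [batch_size]
    kv_indptr: torch.Tensor,           # [batch_size + 1]
    kv_start_idx: Optional[torch.Tensor],
    kv_indices: torch.Tensor,          # [sum(page_kernel_lens)], written in place
) -> None:
    for pid in range(req_pool_indices.numel()):
        pool_idx = int(req_pool_indices[pid])
        start = int(kv_start_idx[pid]) if kv_start_idx is not None else 0
        length = int(page_kernel_lens[pid])
        out = int(kv_indptr[pid])
        # Gather this request's token slots into the flat output buffer.
        kv_indices[out : out + length] = req_to_token[pool_idx, start : start + length]
```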
sglang/srt/layers/moe/cutlass_moe.py
CHANGED
@@ -157,10 +157,6 @@ def cutlass_fused_experts_fp8(
     rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
     rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))

-    if not is_sm100_supported():
-        rep_a1_scales = per_group_transpose(rep_a1_scales, expert_offsets)
-        w1_scale = w1_scale.contiguous()
-
     c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
     c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)

@@ -192,9 +188,6 @@ def cutlass_fused_experts_fp8(
     silu_and_mul(c1, intermediate)

     intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
-    if not is_sm100_supported():
-        a2_scale = per_group_transpose(a2_scale, expert_offsets)
-        w2_scale = w2_scale.contiguous()

     fp8_blockwise_scaled_grouped_mm(
         c2,
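The removed branches were a pre-SM100 fallback that transposed the per-group activation scales and made the weight scales contiguous before the grouped GEMM. For context, a device-capability check along the lines of `is_sm100_supported` could be sketched as below; the actual helper lives in `sglang/srt/layers/utils.py` and may also gate on the installed CUDA toolkit version:

```python
# Hedged sketch of an SM100 (Blackwell, compute capability 10.x) check.
# Not the actual sglang implementation.
from typing import Optional

import torch


def is_sm100_supported_sketch(device: Optional[torch.device] = None) -> bool:
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability(device)
    return major == 10
```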