sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashattention_backend.py

@@ -629,6 +629,7 @@ class FlashAttentionBackend(AttentionBackend):
 # For multi-head latent attention
 q_rope: Optional[torch.Tensor] = None,
 k_rope: Optional[torch.Tensor] = None,
+sinks: Optional[torch.Tensor] = None,
 ):
 if k is not None:
 assert v is not None
@@ -687,6 +688,11 @@ class FlashAttentionBackend(AttentionBackend):
 forward_batch.forward_mode.is_target_verify() and self.topk > 1
 )

+# For fa3 interface version compatibility, we put new fields into conditional keyword args
+kwargs = {}
+if sinks is not None:
+kwargs["sinks"] = sinks
+
 # Get the appropriate page table based on whether we're using local attention
 if use_local_attn:
 local_metadata = metadata.local_attn_metadata
@@ -737,6 +743,7 @@ class FlashAttentionBackend(AttentionBackend):
 k_descale=k_descale,
 v_descale=v_descale,
 return_softmax_lse=use_cascade_attn,
+**kwargs,
 )

 if use_cascade_attn:
@@ -757,6 +764,7 @@ class FlashAttentionBackend(AttentionBackend):
 k_descale=k_descale,
 v_descale=v_descale,
 return_softmax_lse=True,
+**kwargs,
 )
 o, _ = merge_state_v2_wrapper(
 o,
@@ -898,6 +906,7 @@ class FlashAttentionBackend(AttentionBackend):
 # For multi-head latent attention
 q_rope: Optional[torch.Tensor] = None,
 k_rope: Optional[torch.Tensor] = None,
+sinks: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
 if k is not None:
 assert v is not None
@@ -943,6 +952,11 @@ class FlashAttentionBackend(AttentionBackend):
 )
 causal = not layer.is_cross_attention

+# For fa3 interface version compatibility, we put new fields into conditional keyword args
+kwargs = {}
+if sinks is not None:
+kwargs["sinks"] = sinks
+
 k_descale, v_descale = None, None
 # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
 # has corresponding quantization method so that layer.k_scale is not None,
@@ -985,6 +999,7 @@ class FlashAttentionBackend(AttentionBackend):
 softcap=layer.logit_cap,
 k_descale=k_descale,
 v_descale=v_descale,
+**kwargs,
 )
 elif use_local_attn:
 # Use chunked (local) attention batching for self-attention
@@ -1003,6 +1018,7 @@ class FlashAttentionBackend(AttentionBackend):
 softcap=layer.logit_cap,
 k_descale=k_descale,
 v_descale=v_descale,
+**kwargs,
 )
 else:
 page_table = metadata.page_table
@@ -1030,6 +1046,7 @@ class FlashAttentionBackend(AttentionBackend):
 k_descale=k_descale,
 v_descale=v_descale,
 return_softmax_lse=use_cascade_attn,
+**kwargs,
 )
 if use_cascade_attn:
 o, softmax_lse, *rest = result
@@ -1050,6 +1067,7 @@ class FlashAttentionBackend(AttentionBackend):
 k_descale=k_descale,
 v_descale=v_descale,
 return_softmax_lse=True,
+**kwargs,
 )
 )
 o, _ = merge_state_v2(
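The hunks above thread an optional attention-sink tensor through the FlashAttention backend and forward it only as a conditional keyword argument, so older FA3 kernel builds whose attention call has no `sinks` parameter keep working. A minimal sketch of that call pattern; the `call_attention`/`attn_fn` names are illustrative, not part of sglang:

```python
from typing import Callable, Optional

import torch


def call_attention(
    attn_fn: Callable[..., torch.Tensor],
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    sinks: Optional[torch.Tensor] = None,
    **fixed_args,
) -> torch.Tensor:
    # Only forward `sinks` when it is actually provided, so a kernel whose
    # signature predates the argument can still be called unchanged.
    kwargs = {}
    if sinks is not None:
        kwargs["sinks"] = sinks
    return attn_fn(q, k, v, **fixed_args, **kwargs)
```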
sglang/srt/layers/attention/flashinfer_backend.py

@@ -66,6 +66,10 @@ class PrefillMetadata:
 # Reuse this workspace buffer across all flashinfer wrappers
 global_workspace_buffer = None

+# Use as a fast path to override the indptr in flashinfer's plan function
+# This is used to remove some host-to-device copy overhead.
+global_override_indptr_cpu = None
+

 class FlashInferAttnBackend(AttentionBackend):
 """Flashinfer attention kernels."""
@@ -205,6 +209,7 @@ class FlashInferAttnBackend(AttentionBackend):
 self.indices_updater_decode.update(
 forward_batch.req_pool_indices,
 forward_batch.seq_lens,
+forward_batch.seq_lens_cpu,
 forward_batch.seq_lens_sum,
 decode_wrappers=self.decode_wrappers,
 encoder_lens=forward_batch.encoder_lens,
@@ -215,6 +220,7 @@ class FlashInferAttnBackend(AttentionBackend):
 self.indices_updater_prefill.update(
 forward_batch.req_pool_indices,
 forward_batch.seq_lens,
+forward_batch.seq_lens_cpu,
 forward_batch.seq_lens_sum,
 prefix_lens=None,
 prefill_wrappers=self.prefill_wrappers_paged,
@@ -229,6 +235,7 @@ class FlashInferAttnBackend(AttentionBackend):
 self.indices_updater_prefill.update(
 forward_batch.req_pool_indices,
 forward_batch.seq_lens,
+forward_batch.seq_lens_cpu,
 forward_batch.seq_lens_sum,
 prefix_lens=None,
 prefill_wrappers=self.prefill_wrappers_verify,
@@ -252,6 +259,7 @@ class FlashInferAttnBackend(AttentionBackend):
 self.indices_updater_prefill.update(
 forward_batch.req_pool_indices,
 forward_batch.seq_lens,
+forward_batch.seq_lens_cpu,
 forward_batch.seq_lens_sum,
 prefix_lens,
 prefill_wrappers=self.prefill_wrappers_paged,
@@ -327,6 +335,7 @@ class FlashInferAttnBackend(AttentionBackend):
 self.indices_updater_decode.update(
 req_pool_indices,
 seq_lens,
+seq_lens.cpu(),  # may add a little overhead in capture stage
 seq_lens_sum,
 decode_wrappers=decode_wrappers,
 encoder_lens=encoder_lens,
@@ -358,6 +367,7 @@ class FlashInferAttnBackend(AttentionBackend):
 self.indices_updater_prefill.update(
 req_pool_indices,
 seq_lens,
+seq_lens.cpu(),  # may add a little overhead in capture stage
 seq_lens_sum,
 prefix_lens=None,
 prefill_wrappers=prefill_wrappers,
@@ -387,6 +397,7 @@ class FlashInferAttnBackend(AttentionBackend):
 self.indices_updater_prefill.update(
 req_pool_indices,
 seq_lens,
+seq_lens.cpu(),  # may add a little overhead in capture stage
 seq_lens_sum,
 prefix_lens=None,
 prefill_wrappers=prefill_wrappers,
@@ -414,6 +425,7 @@ class FlashInferAttnBackend(AttentionBackend):
 self.indices_updater_decode.update(
 req_pool_indices[:bs],
 seq_lens[:bs],
+seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
 seq_lens_sum,
 decode_wrappers=self.decode_cuda_graph_metadata[bs],
 encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
@@ -423,6 +435,7 @@ class FlashInferAttnBackend(AttentionBackend):
 self.indices_updater_prefill.update(
 req_pool_indices[:bs],
 seq_lens[:bs],
+seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
 seq_lens_sum,
 prefix_lens=None,
 prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
@@ -434,6 +447,7 @@ class FlashInferAttnBackend(AttentionBackend):
 self.indices_updater_prefill.update(
 req_pool_indices[:bs],
 seq_lens[:bs],
+seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
 seq_lens_sum,
 prefix_lens=None,
 prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
@@ -581,7 +595,7 @@ class FlashInferAttnBackend(AttentionBackend):


 class FlashInferIndicesUpdaterDecode:
-def __init__(self, model_runner: ModelRunner, attn_backend:
+def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBackend):
 # Parse Constants
 self.num_qo_heads = (
 model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -614,6 +628,7 @@ class FlashInferIndicesUpdaterDecode:
 self,
 req_pool_indices: torch.Tensor,
 seq_lens: torch.Tensor,
+seq_lens_cpu: Optional[torch.Tensor],
 seq_lens_sum: int,
 decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
 encoder_lens: Optional[torch.Tensor],
@@ -626,6 +641,7 @@ class FlashInferIndicesUpdaterDecode:
 self,
 req_pool_indices: torch.Tensor,
 seq_lens: torch.Tensor,
+seq_lens_cpu: Optional[torch.Tensor],
 seq_lens_sum: int,
 decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
 encoder_lens: Optional[torch.Tensor],
@@ -640,30 +656,39 @@ class FlashInferIndicesUpdaterDecode:
 self.kv_indptr[0],
 None,
 spec_info,
+seq_lens_cpu,
 )

 def update_sliding_window(
 self,
 req_pool_indices: torch.Tensor,
 seq_lens: torch.Tensor,
+seq_lens_cpu: Optional[torch.Tensor],
 seq_lens_sum: int,
 decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
 encoder_lens: Optional[torch.Tensor],
 spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
 ):
+assert self.sliding_window_size is not None
 for wrapper_id in range(2):
 if wrapper_id == 0:
 # Sliding window attention
-paged_kernel_lens_tmp = torch.
-seq_lens,
-torch.tensor(self.sliding_window_size + 1),
+paged_kernel_lens_tmp = torch.clamp(
+seq_lens, max=self.sliding_window_size + 1
 )
-
+if seq_lens_cpu is not None:
+seq_lens_cpu_tmp = torch.clamp(
+seq_lens_cpu, max=self.sliding_window_size + 1
+)
+paged_kernel_lens_sum_tmp = seq_lens_cpu_tmp.sum().item()
+else:
+paged_kernel_lens_sum_tmp = paged_kernel_lens_tmp.sum().item()
 kv_start_idx_tmp = seq_lens - paged_kernel_lens_tmp
 else:
 # Full attention
 paged_kernel_lens_tmp = seq_lens
 paged_kernel_lens_sum_tmp = seq_lens_sum
+seq_lens_cpu_tmp = seq_lens_cpu
 kv_start_idx_tmp = None

 use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
@@ -678,6 +703,7 @@ class FlashInferIndicesUpdaterDecode:
 self.kv_indptr[wrapper_id],
 kv_start_idx_tmp,
 spec_info,
+seq_lens_cpu=seq_lens_cpu_tmp,
 use_sliding_window_kv_pool=use_sliding_window_kv_pool,
 )

@@ -685,6 +711,7 @@ class FlashInferIndicesUpdaterDecode:
 self,
 req_pool_indices: torch.Tensor,
 seq_lens: torch.Tensor,
+seq_lens_cpu: Optional[torch.Tensor],
 seq_lens_sum: int,
 decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
 encoder_lens: Optional[torch.Tensor],
@@ -709,6 +736,7 @@ class FlashInferIndicesUpdaterDecode:
 self.kv_indptr[wrapper_id],
 kv_start_idx,
 spec_info,
+seq_lens_cpu=seq_lens_cpu,
 )

 def call_begin_forward(
@@ -720,6 +748,7 @@ class FlashInferIndicesUpdaterDecode:
 kv_indptr: torch.Tensor,
 kv_start_idx: torch.Tensor,
 spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+seq_lens_cpu: Optional[torch.Tensor],
 use_sliding_window_kv_pool: bool = False,
 ):
 if spec_info is None:
@@ -756,6 +785,14 @@ class FlashInferIndicesUpdaterDecode:
 )
 )

+global global_override_indptr_cpu
+locally_override = False
+if seq_lens_cpu is not None and global_override_indptr_cpu is None:
+locally_override = True
+global_override_indptr_cpu = torch.empty_like(kv_indptr, device="cpu")
+global_override_indptr_cpu[0] = 0
+global_override_indptr_cpu[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0)
+
 wrapper.begin_forward(
 kv_indptr,
 kv_indices,
@@ -769,9 +806,12 @@ class FlashInferIndicesUpdaterDecode:
 non_blocking=True,
 )

+if locally_override:
+global_override_indptr_cpu = None
+

 class FlashInferIndicesUpdaterPrefill:
-def __init__(self, model_runner: ModelRunner, attn_backend:
+def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBackend):
 # Parse Constants
 self.num_qo_heads = (
 model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -806,6 +846,7 @@ class FlashInferIndicesUpdaterPrefill:
 self,
 req_pool_indices: torch.Tensor,
 seq_lens: torch.Tensor,
+seq_lens_cpu: Optional[torch.Tensor],
 seq_lens_sum: int,
 prefix_lens: torch.Tensor,
 prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -820,6 +861,7 @@ class FlashInferIndicesUpdaterPrefill:
 self,
 req_pool_indices: torch.Tensor,
 seq_lens: torch.Tensor,
+seq_lens_cpu: Optional[torch.Tensor],
 seq_lens_sum: int,
 prefix_lens: torch.Tensor,
 prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -853,6 +895,7 @@ class FlashInferIndicesUpdaterPrefill:
 self,
 req_pool_indices: torch.Tensor,
 seq_lens: torch.Tensor,
+seq_lens_cpu: Optional[torch.Tensor],
 seq_lens_sum: int,
 prefix_lens: torch.Tensor,
 prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -898,6 +941,7 @@ class FlashInferIndicesUpdaterPrefill:
 self,
 req_pool_indices: torch.Tensor,
 seq_lens: torch.Tensor,
+seq_lens_cpu: Optional[torch.Tensor],
 seq_lens_sum: int,
 prefix_lens: torch.Tensor,
 prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -1020,11 +1064,6 @@ class FlashInferIndicesUpdaterPrefill:
 )


-# Use as a fast path to override the indptr in flashinfer's plan function
-# This is used to remove some host-to-device copy overhead.
-global global_override_indptr_cpu
-
-
 class FlashInferMultiStepDraftBackend:
 """
 Wrap multiple flashinfer attention backends as one for multiple consecutive
@@ -1056,7 +1095,7 @@ class FlashInferMultiStepDraftBackend:
 self.kv_last_page_len = torch.ones(
 (max_bs,), dtype=torch.int32, device=model_runner.device
 )
-self.attn_backends = []
+self.attn_backends: List[FlashInferAttnBackend] = []
 for i in range(self.speculative_num_steps):
 self.attn_backends.append(
 FlashInferAttnBackend(
@@ -1176,7 +1215,7 @@ class FlashInferMultiStepDraftBackend:
 encoder_lens=None,
 forward_mode=ForwardMode.DECODE,
 spec_info=forward_batch.spec_info,
-seq_lens_cpu=
+seq_lens_cpu=forward_batch.seq_lens_cpu,
 )

 self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
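The `seq_lens_cpu` plumbing above exists so that `call_begin_forward` can build the paged-KV `indptr` on the host and hand it to FlashInfer through the `global_override_indptr_cpu` fast path, instead of paying a device-to-host copy inside the plan step. A rough sketch of that host-side construction, assuming `seq_lens_cpu` is a 1-D tensor of per-request KV lengths (the helper name is illustrative):

```python
import torch


def build_kv_indptr_cpu(seq_lens_cpu: torch.Tensor) -> torch.Tensor:
    """Prefix-sum per-request KV lengths into an indptr array on the host."""
    bs = seq_lens_cpu.numel()
    indptr = torch.zeros(bs + 1, dtype=torch.int64)
    indptr[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0)
    return indptr


# Example: three requests with 3, 5, and 2 cached tokens.
print(build_kv_indptr_cpu(torch.tensor([3, 5, 2])))  # tensor([ 0,  3,  8, 10])
```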
sglang/srt/layers/attention/triton_backend.py

@@ -88,6 +88,7 @@ class TritonAttnBackend(AttentionBackend):
 self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)

 self.req_to_token = model_runner.req_to_token_pool.req_to_token
+self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator

 if not self.skip_prefill:
 self.qo_indptr = torch.zeros(
@@ -197,6 +198,7 @@ class TritonAttnBackend(AttentionBackend):
 forward_batch.req_pool_indices,
 bs,
 self.device,
+self.token_to_kv_pool_allocator,
 )
 )
 window_num_kv_splits = torch.empty(
@@ -225,7 +227,6 @@ class TritonAttnBackend(AttentionBackend):
 mask_indptr = None
 max_extend_len = None
 elif forward_batch.forward_mode.is_target_verify():
-# TODO: Support sliding window in spec inference
 bs = len(forward_batch.req_pool_indices)
 qo_indptr = torch.arange(
 0,
@@ -250,6 +251,20 @@ class TritonAttnBackend(AttentionBackend):
 self.req_to_token.stride(0),
 )

+if self.sliding_window_size is not None and self.sliding_window_size > 0:
+window_kv_indptr, window_kv_indices, window_kv_lens = (
+update_sliding_window_buffer(
+self.window_kv_indptr,
+self.req_to_token,
+self.sliding_window_size,
+forward_batch.seq_lens,
+forward_batch.req_pool_indices,
+bs,
+self.device,
+self.token_to_kv_pool_allocator,
+)
+)
+
 custom_mask = spec_info.custom_mask
 seq_mask_len = self.num_draft_tokens * (
 forward_batch.seq_lens + self.num_draft_tokens
@@ -308,6 +323,7 @@ class TritonAttnBackend(AttentionBackend):
 forward_batch.req_pool_indices,
 bs,
 self.device,
+self.token_to_kv_pool_allocator,
 )

 qo_indptr = self.qo_indptr
@@ -423,14 +439,17 @@ class TritonAttnBackend(AttentionBackend):
 ):
 window_kv_indices = self.cuda_graph_window_kv_indices
 window_num_kv_splits = self.cuda_graph_window_num_kv_splits
-window_kv_indptr, _ =
-
-
-
-
-
-
-
+window_kv_indptr, window_kv_indices, _ = (
+update_sliding_window_buffer_cuda_graph(
+self.window_kv_indptr,
+window_kv_indices,
+self.req_to_token,
+self.sliding_window_size,
+seq_lens[:bs],
+req_pool_indices,
+bs,
+self.token_to_kv_pool_allocator,
+)
 )
 else:
 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
@@ -464,6 +483,22 @@ class TritonAttnBackend(AttentionBackend):
 self.req_to_token.stride(0),
 )

+if self.sliding_window_size is not None and self.sliding_window_size > 0:
+window_kv_indices = self.cuda_graph_window_kv_indices
+window_num_kv_splits = self.cuda_graph_window_num_kv_splits
+window_kv_indptr, window_kv_indices, _ = (
+update_sliding_window_buffer_cuda_graph(
+self.window_kv_indptr,
+window_kv_indices,
+self.req_to_token,
+self.sliding_window_size,
+seq_lens,
+req_pool_indices,
+bs,
+self.token_to_kv_pool_allocator,
+)
+)
+
 custom_mask = self.cuda_graph_custom_mask
 custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
 seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
@@ -557,7 +592,7 @@ class TritonAttnBackend(AttentionBackend):
 ):
 window_num_kv_splits = self.cuda_graph_window_num_kv_splits
 window_kv_indices = self.cuda_graph_window_kv_indices
-_, window_kv_lens = update_sliding_window_buffer_cuda_graph(
+_, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
 self.window_kv_indptr,
 window_kv_indices,
 self.req_to_token,
@@ -565,6 +600,7 @@ class TritonAttnBackend(AttentionBackend):
 seq_lens[:bs],
 req_pool_indices[:bs],
 bs,
+self.token_to_kv_pool_allocator,
 )
 self.get_num_kv_splits(
 window_num_kv_splits[:num_token], window_kv_lens[:bs]
@@ -599,6 +635,19 @@ class TritonAttnBackend(AttentionBackend):
 kv_indices,
 self.req_to_token.stride(0),
 )
+if self.sliding_window_size is not None and self.sliding_window_size > 0:
+window_num_kv_splits = self.cuda_graph_window_num_kv_splits
+window_kv_indices = self.cuda_graph_window_kv_indices
+_, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
+self.window_kv_indptr,
+window_kv_indices,
+self.req_to_token,
+self.sliding_window_size,
+seq_lens,
+req_pool_indices,
+bs,
+self.token_to_kv_pool_allocator,
+)
 custom_mask = self.cuda_graph_custom_mask
 custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
 seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
@@ -637,6 +686,7 @@ class TritonAttnBackend(AttentionBackend):
 layer: RadixAttention,
 forward_batch: ForwardBatch,
 save_kv_cache=True,
+sinks=None,
 ):
 # TODO: reuse the buffer across layers
 if layer.qk_head_dim != layer.v_head_dim:
@@ -680,7 +730,8 @@ class TritonAttnBackend(AttentionBackend):
 self.forward_metadata.max_extend_len,
 layer.scaling,
 layer.logit_cap,
-sliding_window_size,
+sliding_window_size=sliding_window_size,
+sinks=sinks,
 )
 return o

@@ -692,6 +743,7 @@ class TritonAttnBackend(AttentionBackend):
 layer: RadixAttention,
 forward_batch: ForwardBatch,
 save_kv_cache=True,
+sinks=None,
 ):
 # During torch.compile, there is a bug in rotary_emb that causes the
 # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -728,6 +780,7 @@ class TritonAttnBackend(AttentionBackend):
 self.max_kv_splits,
 layer.scaling,
 layer.logit_cap,
+sinks=sinks,
 )
 return o

@@ -932,10 +985,11 @@ def update_sliding_window_buffer(
 req_pool_indices,
 bs,
 device,
+token_to_kv_pool_allocator=None,
 ):
 window_kv_lens = torch.minimum(
 seq_lens,
-torch.tensor(sliding_window_size
+torch.tensor(sliding_window_size),
 )
 window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
 window_kv_indptr = window_kv_indptr[: bs + 1]
@@ -952,6 +1006,14 @@ def update_sliding_window_buffer(
 window_kv_indices,
 req_to_token.stride(0),
 )
+# full to swa index mapping
+if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
+kv_last_index = window_kv_indptr[-1]
+window_kv_indices[:kv_last_index] = (
+token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+window_kv_indices[:kv_last_index]
+)
+)
 return window_kv_indptr, window_kv_indices, window_kv_lens


@@ -963,10 +1025,11 @@ def update_sliding_window_buffer_cuda_graph(
 seq_lens,
 req_pool_indices,
 bs,
+token_to_kv_pool_allocator=None,
 ):
 window_kv_lens = torch.minimum(
 seq_lens,
-torch.tensor(sliding_window_size
+torch.tensor(sliding_window_size),
 )
 window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
 window_kv_indptr = window_kv_indptr[: bs + 1]
@@ -980,4 +1043,12 @@ def update_sliding_window_buffer_cuda_graph(
 window_kv_indices,
 req_to_token.stride(0),
 )
-
+# full to swa index mapping
+if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
+kv_last_index = window_kv_indptr[-1]
+window_kv_indices[:kv_last_index] = (
+token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+window_kv_indices[:kv_last_index]
+)
+)
+return window_kv_indptr, window_kv_indices, window_kv_lens
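For sliding-window layers, the Triton backend now clamps each request's KV length to the window, prefix-sums those lengths into a window-local `indptr`, and, when a hybrid allocator exposes `translate_loc_from_full_to_swa`, remaps the gathered indices from the full KV pool into the sliding-window pool. A standalone sketch of the clamp-and-prefix-sum step only (the helper name is illustrative):

```python
import torch


def window_kv_layout(seq_lens: torch.Tensor, sliding_window_size: int):
    """Clamp per-request KV lengths to the window and build a window-local indptr."""
    window_kv_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size))
    bs = seq_lens.numel()
    window_kv_indptr = torch.zeros(bs + 1, dtype=torch.int64)
    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
    return window_kv_indptr, window_kv_lens


# Example: a window of 4 tokens over requests with 2, 6, and 9 cached tokens.
indptr, lens = window_kv_layout(torch.tensor([2, 6, 9]), sliding_window_size=4)
print(lens)    # tensor([2, 4, 4])
print(indptr)  # tensor([ 0,  2,  6, 10])
```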
sglang/srt/layers/attention/triton_ops/decode_attention.py

@@ -495,6 +495,7 @@ def _fwd_kernel_stage2(
 O,
 kv_indptr,
 num_kv_splits,
+sink_ptr,
 stride_mid_ob,
 stride_mid_oh,
 stride_mid_os,
@@ -504,6 +505,7 @@ def _fwd_kernel_stage2(
 MIN_BLOCK_KV: tl.constexpr,
 BLOCK_DV: tl.constexpr,
 Lv: tl.constexpr,
+HAS_SINK: tl.constexpr,
 ):
 cur_batch = tl.program_id(0)
 cur_head = tl.program_id(1)
@@ -545,6 +547,10 @@ def _fwd_kernel_stage2(
 e_sum = e_sum * old_scale + exp_logic
 e_max = n_e_max

+if HAS_SINK:
+cur_sink = tl.load(sink_ptr + cur_head)
+e_sum += tl.exp(cur_sink - e_max)
+
 tl.store(
 O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
 acc / e_sum,
@@ -561,12 +567,14 @@ def _decode_softmax_reducev_fwd(
 kv_indptr,
 num_kv_splits,
 max_kv_splits,
+sinks=None,
 ):
 batch, head_num = q.shape[0], q.shape[1]
 Lv = v_buffer.shape[-1]
 BLOCK_DV = triton.next_power_of_2(Lv)

 MAX_KV_SPLITS = max_kv_splits
+HAS_SINK = sinks is not None

 extra_kargs = {}
 if _is_hip:
@@ -581,6 +589,7 @@ def _decode_softmax_reducev_fwd(
 o,
 kv_indptr,
 num_kv_splits,
+sinks,
 logits.stride(0),
 logits.stride(1),
 logits.stride(2),
@@ -590,6 +599,7 @@ def _decode_softmax_reducev_fwd(
 MIN_BLOCK_KV=_MIN_BLOCK_KV,
 BLOCK_DV=BLOCK_DV,
 Lv=Lv,
+HAS_SINK=HAS_SINK,
 num_warps=4,
 num_stages=2,
 **extra_kargs,
@@ -609,6 +619,7 @@ def decode_attention_fwd_normal(
 max_kv_splits,
 sm_scale,
 logit_cap=0.0,
+sinks=None,
 ):
 _decode_att_m_fwd(
 q,
@@ -632,6 +643,7 @@ def decode_attention_fwd_normal(
 kv_indptr,
 num_kv_splits,
 max_kv_splits,
+sinks,
 )


@@ -648,6 +660,7 @@ def decode_attention_fwd_grouped(
 max_kv_splits,
 sm_scale,
 logit_cap=0.0,
+sinks=None,
 ):
 _decode_grouped_att_m_fwd(
 q,
@@ -671,6 +684,7 @@ def decode_attention_fwd_grouped(
 kv_indptr,
 num_kv_splits,
 max_kv_splits,
+sinks,
 )


@@ -687,6 +701,7 @@ def decode_attention_fwd(
 max_kv_splits,
 sm_scale,
 logit_cap=0.0,
+sinks=None,
 ):
 assert max_kv_splits == attn_logits.shape[2]
 assert q.shape[0] <= kv_indptr.shape[0] - 1
@@ -709,6 +724,7 @@ def decode_attention_fwd(
 max_kv_splits,
 sm_scale,
 logit_cap=logit_cap,
+sinks=sinks,
 )
 else:
 # GQA/MQA/MLA
@@ -725,4 +741,5 @@ def decode_attention_fwd(
 max_kv_splits,
 sm_scale,
 logit_cap=logit_cap,
+sinks=sinks,
 )