sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +113 -17
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +11 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +4 -3
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +71 -0
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/vision.py +13 -5
- sglang/srt/layers/communicator.py +21 -4
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +2 -7
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +77 -73
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +3 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +55 -30
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +28 -7
- sglang/srt/managers/scheduler.py +26 -12
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +24 -6
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +7 -6
- sglang/srt/model_executor/forward_batch_info.py +35 -14
- sglang/srt/model_executor/model_runner.py +19 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +72 -33
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +24 -12
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +142 -7
- sglang/srt/two_batch_overlap.py +157 -5
- sglang/srt/utils.py +38 -2
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/triton_backend.py

@@ -88,6 +88,7 @@ class TritonAttnBackend(AttentionBackend):
         self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)

         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator

         if not self.skip_prefill:
             self.qo_indptr = torch.zeros(
@@ -197,6 +198,7 @@ class TritonAttnBackend(AttentionBackend):
                         forward_batch.req_pool_indices,
                         bs,
                         self.device,
+                        self.token_to_kv_pool_allocator,
                     )
                 )
                 window_num_kv_splits = torch.empty(
@@ -225,7 +227,6 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr = None
             max_extend_len = None
         elif forward_batch.forward_mode.is_target_verify():
-            # TODO: Support sliding window in spec inference
             bs = len(forward_batch.req_pool_indices)
             qo_indptr = torch.arange(
                 0,
@@ -250,6 +251,20 @@ class TritonAttnBackend(AttentionBackend):
                 self.req_to_token.stride(0),
             )

+            if self.sliding_window_size is not None and self.sliding_window_size > 0:
+                window_kv_indptr, window_kv_indices, window_kv_lens = (
+                    update_sliding_window_buffer(
+                        self.window_kv_indptr,
+                        self.req_to_token,
+                        self.sliding_window_size,
+                        forward_batch.seq_lens,
+                        forward_batch.req_pool_indices,
+                        bs,
+                        self.device,
+                        self.token_to_kv_pool_allocator,
+                    )
+                )
+
             custom_mask = spec_info.custom_mask
             seq_mask_len = self.num_draft_tokens * (
                 forward_batch.seq_lens + self.num_draft_tokens
@@ -308,6 +323,7 @@ class TritonAttnBackend(AttentionBackend):
                     forward_batch.req_pool_indices,
                     bs,
                     self.device,
+                    self.token_to_kv_pool_allocator,
                 )

             qo_indptr = self.qo_indptr
@@ -423,14 +439,17 @@ class TritonAttnBackend(AttentionBackend):
             ):
                 window_kv_indices = self.cuda_graph_window_kv_indices
                 window_num_kv_splits = self.cuda_graph_window_num_kv_splits
-                window_kv_indptr, _ = update_sliding_window_buffer_cuda_graph(
-                    self.window_kv_indptr,
-                    window_kv_indices,
-                    self.req_to_token,
-                    self.sliding_window_size,
-                    seq_lens[:bs],
-                    req_pool_indices,
-                    bs,
+                window_kv_indptr, window_kv_indices, _ = (
+                    update_sliding_window_buffer_cuda_graph(
+                        self.window_kv_indptr,
+                        window_kv_indices,
+                        self.req_to_token,
+                        self.sliding_window_size,
+                        seq_lens[:bs],
+                        req_pool_indices,
+                        bs,
+                        self.token_to_kv_pool_allocator,
+                    )
                 )
             else:
                 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
@@ -464,6 +483,22 @@ class TritonAttnBackend(AttentionBackend):
                 self.req_to_token.stride(0),
             )

+            if self.sliding_window_size is not None and self.sliding_window_size > 0:
+                window_kv_indices = self.cuda_graph_window_kv_indices
+                window_num_kv_splits = self.cuda_graph_window_num_kv_splits
+                window_kv_indptr, window_kv_indices, _ = (
+                    update_sliding_window_buffer_cuda_graph(
+                        self.window_kv_indptr,
+                        window_kv_indices,
+                        self.req_to_token,
+                        self.sliding_window_size,
+                        seq_lens,
+                        req_pool_indices,
+                        bs,
+                        self.token_to_kv_pool_allocator,
+                    )
+                )
+
             custom_mask = self.cuda_graph_custom_mask
             custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
             seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
@@ -557,7 +592,7 @@ class TritonAttnBackend(AttentionBackend):
             ):
                 window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                 window_kv_indices = self.cuda_graph_window_kv_indices
-                _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
+                _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
                     self.window_kv_indptr,
                     window_kv_indices,
                     self.req_to_token,
@@ -565,6 +600,7 @@ class TritonAttnBackend(AttentionBackend):
                     seq_lens[:bs],
                     req_pool_indices[:bs],
                     bs,
+                    self.token_to_kv_pool_allocator,
                 )
                 self.get_num_kv_splits(
                     window_num_kv_splits[:num_token], window_kv_lens[:bs]
@@ -599,6 +635,19 @@ class TritonAttnBackend(AttentionBackend):
                 kv_indices,
                 self.req_to_token.stride(0),
             )
+            if self.sliding_window_size is not None and self.sliding_window_size > 0:
+                window_num_kv_splits = self.cuda_graph_window_num_kv_splits
+                window_kv_indices = self.cuda_graph_window_kv_indices
+                _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
+                    self.window_kv_indptr,
+                    window_kv_indices,
+                    self.req_to_token,
+                    self.sliding_window_size,
+                    seq_lens,
+                    req_pool_indices,
+                    bs,
+                    self.token_to_kv_pool_allocator,
+                )
             custom_mask = self.cuda_graph_custom_mask
             custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
             seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
@@ -637,6 +686,7 @@ class TritonAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        sinks=None,
     ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
@@ -680,7 +730,8 @@ class TritonAttnBackend(AttentionBackend):
             self.forward_metadata.max_extend_len,
             layer.scaling,
             layer.logit_cap,
-            sliding_window_size,
+            sliding_window_size=sliding_window_size,
+            sinks=sinks,
         )
         return o

@@ -692,6 +743,7 @@ class TritonAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        sinks=None,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -728,6 +780,7 @@ class TritonAttnBackend(AttentionBackend):
             self.max_kv_splits,
             layer.scaling,
             layer.logit_cap,
+            sinks=sinks,
         )
         return o

@@ -932,10 +985,11 @@ def update_sliding_window_buffer(
    req_pool_indices,
    bs,
    device,
+    token_to_kv_pool_allocator=None,
 ):
    window_kv_lens = torch.minimum(
        seq_lens,
-        torch.tensor(sliding_window_size + 1),
+        torch.tensor(sliding_window_size),
    )
    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
    window_kv_indptr = window_kv_indptr[: bs + 1]
@@ -952,6 +1006,14 @@ def update_sliding_window_buffer(
        window_kv_indices,
        req_to_token.stride(0),
    )
+    # full to swa index mapping
+    if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
+        kv_last_index = window_kv_indptr[-1]
+        window_kv_indices[:kv_last_index] = (
+            token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+                window_kv_indices[:kv_last_index]
+            )
+        )
    return window_kv_indptr, window_kv_indices, window_kv_lens


@@ -963,10 +1025,11 @@ def update_sliding_window_buffer_cuda_graph(
    seq_lens,
    req_pool_indices,
    bs,
+    token_to_kv_pool_allocator=None,
 ):
    window_kv_lens = torch.minimum(
        seq_lens,
-        torch.tensor(sliding_window_size + 1),
+        torch.tensor(sliding_window_size),
    )
    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
    window_kv_indptr = window_kv_indptr[: bs + 1]
@@ -980,4 +1043,12 @@ def update_sliding_window_buffer_cuda_graph(
        window_kv_indices,
        req_to_token.stride(0),
    )
-    return window_kv_indptr, window_kv_lens
+    # full to swa index mapping
+    if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
+        kv_last_index = window_kv_indptr[-1]
+        window_kv_indices[:kv_last_index] = (
+            token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+                window_kv_indices[:kv_last_index]
+            )
+        )
+    return window_kv_indptr, window_kv_indices, window_kv_lens
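
Note: update_sliding_window_buffer and its CUDA-graph variant now return the window KV indices alongside the indptr, and optionally remap those indices from the full KV pool into a separate sliding-window (SWA) pool. Below is a minimal, PyTorch-only sketch of that bookkeeping; ToySWAAllocator and its mapping table are hypothetical stand-ins, assuming only the translate_loc_from_full_to_swa hook that the diff probes with hasattr (the real code gathers indices from req_to_token with a Triton kernel):

    import torch

    # Hypothetical allocator: only the translate_loc_from_full_to_swa hook
    # mirrors what the diff checks for via hasattr().
    class ToySWAAllocator:
        def __init__(self, full_to_swa):
            self.full_to_swa = full_to_swa  # precomputed full-pool -> SWA-pool table

        def translate_loc_from_full_to_swa(self, loc):
            return self.full_to_swa[loc]

    window = 4
    seq_lens = torch.tensor([3, 7])  # two requests
    window_kv_lens = torch.minimum(seq_lens, torch.tensor(window))
    window_kv_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.long)
    window_kv_indptr[1:] = torch.cumsum(window_kv_lens, dim=0)

    # Keep only the last `window` token slots of each request.
    req_to_token = torch.arange(16).reshape(2, 8)  # toy request -> token-slot table
    window_kv_indices = torch.cat(
        [req_to_token[i, l - w : l] for i, (l, w) in enumerate(zip(seq_lens, window_kv_lens))]
    )

    # If sliding-window layers use a separate KV pool, remap into it.
    allocator = ToySWAAllocator(full_to_swa=torch.arange(16).flip(0))
    if hasattr(allocator, "translate_loc_from_full_to_swa"):
        window_kv_indices = allocator.translate_loc_from_full_to_swa(window_kv_indices)
    print(window_kv_indptr.tolist(), window_kv_indices.tolist())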
sglang/srt/layers/attention/triton_ops/decode_attention.py

@@ -495,6 +495,7 @@ def _fwd_kernel_stage2(
    O,
    kv_indptr,
    num_kv_splits,
+    sink_ptr,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
@@ -504,6 +505,7 @@ def _fwd_kernel_stage2(
    MIN_BLOCK_KV: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    Lv: tl.constexpr,
+    HAS_SINK: tl.constexpr,
 ):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
@@ -545,6 +547,10 @@ def _fwd_kernel_stage2(
        e_sum = e_sum * old_scale + exp_logic
        e_max = n_e_max

+    if HAS_SINK:
+        cur_sink = tl.load(sink_ptr + cur_head)
+        e_sum += tl.exp(cur_sink - e_max)
+
    tl.store(
        O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
        acc / e_sum,
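
Note: the HAS_SINK branch above adds exp(sink - e_max) to the softmax denominator after the split-KV merge. Mathematically that is the same as appending one extra per-head logit whose value vector is zero, so the sink only absorbs probability mass. A NumPy sanity check of that identity (this is a reference computation, not the kernel):

    import numpy as np

    rng = np.random.default_rng(0)
    logits = rng.normal(size=8)        # attention scores for one head
    values = rng.normal(size=(8, 4))   # value vectors for the same tokens
    sink = 1.5                         # per-head sink logit

    # (a) direct: softmax over [logits, sink]; the sink row contributes no value
    ext = np.concatenate([logits, [sink]])
    p = np.exp(ext - ext.max()) / np.exp(ext - ext.max()).sum()
    out_direct = p[:-1] @ values

    # (b) kernel-style: ordinary max-shifted softmax, then bump the denominator
    e_max = logits.max()
    acc = np.exp(logits - e_max) @ values
    e_sum = np.exp(logits - e_max).sum() + np.exp(sink - e_max)
    out_kernel = acc / e_sum

    assert np.allclose(out_direct, out_kernel)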
@@ -561,12 +567,14 @@ def _decode_softmax_reducev_fwd(
    kv_indptr,
    num_kv_splits,
    max_kv_splits,
+    sinks=None,
 ):
    batch, head_num = q.shape[0], q.shape[1]
    Lv = v_buffer.shape[-1]
    BLOCK_DV = triton.next_power_of_2(Lv)

    MAX_KV_SPLITS = max_kv_splits
+    HAS_SINK = sinks is not None

    extra_kargs = {}
    if _is_hip:
@@ -581,6 +589,7 @@ def _decode_softmax_reducev_fwd(
        o,
        kv_indptr,
        num_kv_splits,
+        sinks,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),
@@ -590,6 +599,7 @@ def _decode_softmax_reducev_fwd(
        MIN_BLOCK_KV=_MIN_BLOCK_KV,
        BLOCK_DV=BLOCK_DV,
        Lv=Lv,
+        HAS_SINK=HAS_SINK,
        num_warps=4,
        num_stages=2,
        **extra_kargs,
@@ -609,6 +619,7 @@ def decode_attention_fwd_normal(
    max_kv_splits,
    sm_scale,
    logit_cap=0.0,
+    sinks=None,
 ):
    _decode_att_m_fwd(
        q,
@@ -632,6 +643,7 @@ def decode_attention_fwd_normal(
        kv_indptr,
        num_kv_splits,
        max_kv_splits,
+        sinks,
    )


@@ -648,6 +660,7 @@ def decode_attention_fwd_grouped(
    max_kv_splits,
    sm_scale,
    logit_cap=0.0,
+    sinks=None,
 ):
    _decode_grouped_att_m_fwd(
        q,
@@ -671,6 +684,7 @@ def decode_attention_fwd_grouped(
        kv_indptr,
        num_kv_splits,
        max_kv_splits,
+        sinks,
    )


@@ -687,6 +701,7 @@ def decode_attention_fwd(
    max_kv_splits,
    sm_scale,
    logit_cap=0.0,
+    sinks=None,
 ):
    assert max_kv_splits == attn_logits.shape[2]
    assert q.shape[0] <= kv_indptr.shape[0] - 1
@@ -709,6 +724,7 @@ def decode_attention_fwd(
            max_kv_splits,
            sm_scale,
            logit_cap=logit_cap,
+            sinks=sinks,
        )
    else:
        # GQA/MQA/MLA
@@ -725,4 +741,5 @@ def decode_attention_fwd(
            max_kv_splits,
            sm_scale,
            logit_cap=logit_cap,
+            sinks=sinks,
        )
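
Note: the sinks argument is threaded unchanged through decode_attention_fwd -> decode_attention_fwd_normal / decode_attention_fwd_grouped -> _decode_softmax_reducev_fwd, and _fwd_kernel_stage2 reads it as sink_ptr + cur_head, i.e. a flat per-head tensor. A hypothetical caller (names illustrative, not SGLang's actual wiring) would therefore hold one learned logit per attention head:

    import torch

    num_heads = 32
    # One learned sink logit per head, e.g. registered on the attention layer.
    sinks = torch.nn.Parameter(torch.zeros(num_heads))
    # ...forwarded down the call chain, roughly:
    # decode_attention_fwd(q, k_buffer, v_buffer, o, ..., logit_cap=0.0, sinks=sinks)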
sglang/srt/layers/attention/triton_ops/extend_attention.py

@@ -51,6 +51,7 @@ def _fwd_kernel(
    kv_indices,
    mask_ptr,
    mask_indptr,
+    sink_ptr,
    sm_scale,
    kv_group_num,
    stride_qbs,
@@ -78,6 +79,7 @@ def _fwd_kernel(
    IS_CAUSAL: tl.constexpr,
    SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
    STORE_TRANSPOSE: tl.constexpr,
+    HAS_SINK: tl.constexpr,
 ):
    cur_seq = tl.program_id(0)
    cur_head = tl.program_id(1)
@@ -132,38 +134,6 @@ def _fwd_kernel(
        start_n = tl.multiple_of(start_n, BLOCK_N)
        mask_n = (start_n + offs_n) < cur_seq_len_prefix

-        offs_kv_loc = tl.load(
-            kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
-        )
-
-        # load k in transposed way
-        offs_buf_k = (
-            offs_kv_loc[None, :] * stride_buf_kbs
-            + cur_kv_head * stride_buf_kh
-            + offs_d[:, None]
-        )
-        k = tl.load(
-            K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
-        )
-
-        qk = tl.dot(q.to(k.dtype), k)
-        if BLOCK_DPE > 0:
-            offs_kpe = (
-                offs_kv_loc[None, :] * stride_buf_kbs
-                + cur_kv_head * stride_buf_kh
-                + offs_dpe[:, None]
-            )
-            kpe = tl.load(
-                K_Buffer + offs_kpe,
-                mask=mask_n[None, :],
-                other=0.0,
-            )
-            qk += tl.dot(qpe.to(kpe.dtype), kpe)
-        qk *= sm_scale
-
-        if logit_cap > 0:
-            qk = logit_cap * tanh(qk / logit_cap)
-
        final_mask = mask_m[:, None] & mask_n[None, :]
        if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK:
            custom_mask = tl.load(
@@ -178,29 +148,77 @@ def _fwd_kernel(
            final_mask &= custom_mask
        if SLIDING_WINDOW_SIZE > 0:
            # Add mask where q_id <= kv_id + sliding_window_size
-            window_mask = (cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m[:, None]) <= (
-                start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
-            )
+            # q_id = prefix_len + cur_m, kv_id = cur_n
+            window_mask = (
+                cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m[:, None]
+            ) <= (start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE)
            final_mask &= window_mask
-        qk = tl.where(final_mask, qk, float("-inf"))

-        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
-        re_scale = tl.exp(e_max - n_e_max)
-        p = tl.exp(qk - n_e_max[:, None])
-        deno = deno * re_scale + tl.sum(p, 1)
+        SKIP_TILE = False
+        if (USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK) or SLIDING_WINDOW_SIZE > 0:
+            SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0

-        offs_buf_v = (
-            offs_kv_loc[:, None] * stride_buf_vbs
-            + cur_kv_head * stride_buf_vh
-            + offs_dv[None, :]
-        )
-        v = tl.load(
-            V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
-        )
-        p = p.to(v.dtype)
-        acc = acc * re_scale[:, None] + tl.dot(p, v)
+        if not SKIP_TILE:
+            offs_kv_loc = tl.load(
+                kv_indices + cur_seq_kv_start_idx + start_n + offs_n,
+                mask=mask_n,
+                other=0,
+            )

-        e_max = n_e_max
+            # load k in transposed way
+            offs_buf_k = (
+                offs_kv_loc[None, :] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_d[:, None]
+            )
+            k = tl.load(
+                K_Buffer + offs_buf_k,
+                mask=(mask_n[None, :]) & (mask_d[:, None]),
+                other=0.0,
+            )
+
+            qk = tl.dot(q.to(k.dtype), k)
+            if BLOCK_DPE > 0:
+                offs_kpe = (
+                    offs_kv_loc[None, :] * stride_buf_kbs
+                    + cur_kv_head * stride_buf_kh
+                    + offs_dpe[:, None]
+                )
+                kpe = tl.load(
+                    K_Buffer + offs_kpe,
+                    mask=mask_n[None, :],
+                    other=0.0,
+                )
+                qk += tl.dot(qpe.to(kpe.dtype), kpe)
+            qk *= sm_scale
+
+            if logit_cap > 0:
+                qk = logit_cap * tanh(qk / logit_cap)
+
+            qk = tl.where(final_mask, qk, float("-inf"))
+
+            row_max = tl.max(qk, 1)
+            row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
+            n_e_max = tl.maximum(row_max_fixed, e_max)
+
+            re_scale = tl.exp(e_max - n_e_max)
+            p = tl.exp(qk - n_e_max[:, None])
+            deno = deno * re_scale + tl.sum(p, 1)
+
+            offs_buf_v = (
+                offs_kv_loc[:, None] * stride_buf_vbs
+                + cur_kv_head * stride_buf_vh
+                + offs_dv[None, :]
+            )
+            v = tl.load(
+                V_Buffer + offs_buf_v,
+                mask=mask_n[:, None] & mask_dv[None, :],
+                other=0.0,
+            )
+            p = p.to(v.dtype)
+            acc = acc * re_scale[:, None] + tl.dot(p, v)
+
+            e_max = n_e_max

    # stage 2: compute the triangle part

@@ -213,35 +231,7 @@ def _fwd_kernel(
        start_n = tl.multiple_of(start_n, BLOCK_N)
        mask_n = (start_n + offs_n) < cur_block_m_end

-        # load k in transposed way
-        offs_k = (
-            (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
-            + cur_kv_head * stride_kh
-            + offs_d[:, None]
-        )
-        k = tl.load(
-            K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
-        )
-
-        qk = tl.dot(q, k, out_dtype=tl.float32)
-        if BLOCK_DPE > 0:
-            offs_kpe = (
-                (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
-                + cur_kv_head * stride_kh
-                + offs_dpe[:, None]
-            )
-            kpe = tl.load(
-                K_Extend + offs_kpe,
-                mask=mask_n[None, :],
-                other=0.0,
-            )
-            qk += tl.dot(qpe, kpe)
-
-        qk *= sm_scale
-
-        if logit_cap > 0:
-            qk = logit_cap * tanh(qk / logit_cap)
-
+        final_mask = mask_m[:, None] & mask_n[None, :]
        if USE_CUSTOM_MASK:
            custom_mask = tl.load(
                mask_ptr
@@ -254,34 +244,84 @@ def _fwd_kernel(
                other=0,
            )
            custom_mask &= mask_m[:, None] & mask_n[None, :]
-            qk = tl.where(custom_mask, qk, float("-inf"))
+            final_mask &= custom_mask
        elif IS_CAUSAL:
            mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
                start_n + offs_n[None, :]
            )
            mask_causual &= mask_m[:, None] & mask_n[None, :]
-            qk = tl.where(mask_causual, qk, float("-inf"))
+            final_mask &= mask_causual
        else:
            mask_non_causal = mask_m[:, None] & mask_n[None, :]
-            qk = tl.where(mask_non_causal, qk, float("-inf"))
+            final_mask &= mask_non_causal
+
+        if SLIDING_WINDOW_SIZE > 0:
+            # Add mask where q_id <= kv_id + sliding_window_size
+            window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
+                start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
+            )
+            final_mask &= window_mask

-        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
-        re_scale = tl.exp(e_max - n_e_max)
-        p = tl.exp(qk - n_e_max[:, None])
-        deno = deno * re_scale + tl.sum(p, 1)
+        SKIP_TILE = False
+        if USE_CUSTOM_MASK or SLIDING_WINDOW_SIZE > 0:
+            SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0

-        offs_v = (
-            (cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
-            + cur_kv_head * stride_vh
-            + offs_dv[None, :]
-        )
-        v = tl.load(
-            V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
-        )
-        p = p.to(v.dtype)
-        acc = acc * re_scale[:, None] + tl.dot(p, v)
+        if not SKIP_TILE:
+            # load k in transposed way
+            offs_k = (
+                (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
+                + cur_kv_head * stride_kh
+                + offs_d[:, None]
+            )
+            k = tl.load(
+                K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
+            )

-        e_max = n_e_max
+            qk = tl.dot(q, k, out_dtype=tl.float32)
+            if BLOCK_DPE > 0:
+                offs_kpe = (
+                    (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
+                    + cur_kv_head * stride_kh
+                    + offs_dpe[:, None]
+                )
+                kpe = tl.load(
+                    K_Extend + offs_kpe,
+                    mask=mask_n[None, :],
+                    other=0.0,
+                )
+                qk += tl.dot(qpe, kpe)
+
+            qk *= sm_scale
+
+            if logit_cap > 0:
+                qk = logit_cap * tanh(qk / logit_cap)
+
+            qk = tl.where(final_mask, qk, float("-inf"))
+
+            row_max = tl.max(qk, 1)
+            row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
+            n_e_max = tl.maximum(row_max_fixed, e_max)
+
+            re_scale = tl.exp(e_max - n_e_max)
+            p = tl.exp(qk - n_e_max[:, None])
+            deno = deno * re_scale + tl.sum(p, 1)
+
+            offs_v = (
+                (cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
+                + cur_kv_head * stride_vh
+                + offs_dv[None, :]
+            )
+            v = tl.load(
+                V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
+            )
+            p = p.to(v.dtype)
+            acc = acc * re_scale[:, None] + tl.dot(p, v)
+
+            e_max = n_e_max
+
+        if HAS_SINK:
+            cur_sink = tl.load(sink_ptr + cur_head)
+            deno += tl.exp(cur_sink - e_max)

    offs_o = (
        (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
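
Note: the restructuring above computes final_mask before touching K/V and wraps the loads, QK matmul, and online-softmax update in `if not SKIP_TILE:`. With a sliding window, whole K/V tiles can fall outside every query's window; their combined mask is all-False, so the kernel's tl.max(...) == 0 test lets it skip that tile's memory traffic and matmuls entirely. A toy NumPy check of the skip condition (block sizes and positions are made up for illustration):

    import numpy as np

    BLOCK_M, BLOCK_N, window = 4, 4, 3
    q_start = 12                            # queries at positions 12..15
    offs_m = np.arange(BLOCK_M)[:, None]
    for start_n in range(0, 16, BLOCK_N):   # one K/V tile per iteration
        offs_n = start_n + np.arange(BLOCK_N)[None, :]
        causal = (q_start + offs_m) >= offs_n
        in_window = (q_start + offs_m) <= offs_n + window
        final_mask = causal & in_window
        skip_tile = final_mask.astype(np.int32).max() == 0  # mirrors the kernel test
        print(f"tile at kv {start_n}..{start_n + BLOCK_N - 1}: skip={skip_tile}")

The first two tiles print skip=True: only tiles overlapping the trailing window contribute.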
@@ -321,6 +361,7 @@ def extend_attention_fwd(
    logit_cap=0.0,
    skip_prefix_custom_mask=True,
    sliding_window_size=-1,
+    sinks=None,
 ):
    """
    q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -386,6 +427,8 @@ def extend_attention_fwd(
    # Skip custom mask for prefix part
    SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask

+    HAS_SINK = sinks is not None
+
    grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
    num_stages = 1

@@ -405,6 +448,7 @@ def extend_attention_fwd(
        kv_indices,
        custom_mask,
        mask_indptr,
+        sinks,
        sm_scale,
        kv_group_num,
        q_extend.stride(0),
@@ -431,6 +475,7 @@ def extend_attention_fwd(
        USE_CUSTOM_MASK=USE_CUSTOM_MASK,
        IS_CAUSAL=is_causal,
        SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
+        HAS_SINK=HAS_SINK,
        STORE_TRANSPOSE=_is_hip,
        num_warps=num_warps,
        num_stages=num_stages,