sglang 0.4.10.post1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff compares the published contents of the two package versions as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +113 -17
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -117
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +3 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +22 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +8 -5
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +106 -15
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +55 -13
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
- sglang/srt/layers/attention/vision.py +40 -15
- sglang/srt/layers/communicator.py +35 -8
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +9 -8
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +87 -107
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +442 -58
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +169 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +59 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +8 -7
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +15 -4
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +10 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +61 -32
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +21 -4
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +30 -8
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +170 -18
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +59 -22
- sglang/srt/managers/tokenizer_manager.py +137 -67
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -21
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/model_executor/cuda_graph_runner.py +24 -9
- sglang/srt/model_executor/forward_batch_info.py +48 -17
- sglang/srt/model_executor/model_runner.py +24 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +95 -50
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +102 -27
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +7 -4
- sglang/srt/models/qwen3_moe.py +39 -14
- sglang/srt/models/step3_vl.py +10 -1
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +218 -23
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +163 -9
- sglang/srt/utils.py +41 -26
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/runners.py +4 -4
- sglang/test/test_utils.py +4 -4
- sglang/version.py +1 -1
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +18 -15
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +143 -116
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashattention_backend.py

```diff
@@ -1406,7 +1406,7 @@ class FlashAttentionBackend(AttentionBackend):
                 )
                 metadata.page_table = self.decode_cuda_graph_metadata[
                     "page_table_draft_decode"
-                ][req_pool_indices, :]
+                ][:bs, :]
                 self.decode_cuda_graph_metadata[bs] = metadata
             else:
                 # When top k > 1, we need two specific draft decode metadata, and then merge states
@@ -1424,7 +1424,7 @@ class FlashAttentionBackend(AttentionBackend):
                 ][: bs + 1]
                 metadata.page_table = self.draft_decode_metadata_topk_normal[
                     "page_table"
-                ][req_pool_indices, :]
+                ][:bs, :]

                 # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
                 metadata_expand.cache_seqlens_int32 = (
@@ -1461,7 +1461,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.max_seq_len_k = seq_lens.max().item()
             # Precompute page table
             metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
-                req_pool_indices, :
+                :bs, :
             ]
             # Precompute cumulative sequence lengths
             metadata.cu_seqlens_q = torch.arange(
@@ -1498,9 +1498,7 @@ class FlashAttentionBackend(AttentionBackend):
                     : (bs + 1)
                 ]

-                metadata.page_table = self.target_verify_metadata["page_table"][
-                    req_pool_indices, :
-                ]
+                metadata.page_table = self.target_verify_metadata["page_table"][:bs, :]

                 self.target_verify_metadata[bs] = metadata
             else:
@@ -1519,7 +1517,7 @@ class FlashAttentionBackend(AttentionBackend):
                 ][: bs + 1]
                 metadata.page_table = self.target_verify_metadata_topk_normal[
                     "page_table"
-                ][req_pool_indices, :]
+                ][:bs, :]

                 # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
                 metadata_expand.cache_seqlens_int32 = (
@@ -1562,9 +1560,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][
                 : (bs + 1)
             ]
-            metadata.page_table = self.draft_extend_metadata["page_table"][
-                req_pool_indices, :
-            ]
+            metadata.page_table = self.draft_extend_metadata["page_table"][:bs, :]

             self.draft_extend_metadata[bs] = metadata

@@ -1578,7 +1574,7 @@ class FlashAttentionBackend(AttentionBackend):
             ][: (encoder_bs + 1)]

             metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][
-                req_pool_indices, :
+                :bs, :
             ]

         self.forward_metadata = metadata
```
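The pattern repeated across these FlashAttentionBackend hunks is the switch from `page_table[req_pool_indices, :]` to `page_table[:bs, :]` when building CUDA-graph metadata. The gather depends on which pool slots happen to hold the captured batch, while the slice is a fixed view of the buffer's first `bs` rows, which is presumably why it is preferred for graph capture. A minimal sketch of the indexing difference (shapes and names hypothetical, not taken from the diff):

```python
import torch

# Stand-ins for the preallocated metadata buffer and a live batch of 3 requests.
max_reqs, max_pages, bs = 8, 4, 3
page_table = torch.arange(max_reqs * max_pages).reshape(max_reqs, max_pages)
req_pool_indices = torch.tensor([5, 2, 7])  # pool slots the live requests occupy

gathered = page_table[req_pool_indices, :]  # advanced indexing: copies rows 5, 2, 7
sliced = page_table[:bs, :]                 # basic slicing: a view of rows 0..2

assert gathered.shape == sliced.shape == (bs, max_pages)
assert sliced.data_ptr() == page_table.data_ptr()    # view shares the buffer's storage
assert gathered.data_ptr() != page_table.data_ptr()  # gather allocates fresh storage
```

A captured graph bakes tensor addresses into the replay, so a stable view into the persistent buffer, refreshed in place before each replay, does not depend on runtime `req_pool_indices` values the way a per-capture gather would.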
sglang/srt/layers/attention/triton_backend.py

```diff
@@ -88,6 +88,7 @@ class TritonAttnBackend(AttentionBackend):
         self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)

         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator

         if not self.skip_prefill:
             self.qo_indptr = torch.zeros(
@@ -197,6 +198,7 @@ class TritonAttnBackend(AttentionBackend):
                         forward_batch.req_pool_indices,
                         bs,
                         self.device,
+                        self.token_to_kv_pool_allocator,
                     )
                 )
                 window_num_kv_splits = torch.empty(
@@ -225,7 +227,6 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr = None
             max_extend_len = None
         elif forward_batch.forward_mode.is_target_verify():
-            # TODO: Support sliding window in spec inference
             bs = len(forward_batch.req_pool_indices)
             qo_indptr = torch.arange(
                 0,
@@ -250,6 +251,20 @@ class TritonAttnBackend(AttentionBackend):
                 self.req_to_token.stride(0),
             )

+            if self.sliding_window_size is not None and self.sliding_window_size > 0:
+                window_kv_indptr, window_kv_indices, window_kv_lens = (
+                    update_sliding_window_buffer(
+                        self.window_kv_indptr,
+                        self.req_to_token,
+                        self.sliding_window_size,
+                        forward_batch.seq_lens,
+                        forward_batch.req_pool_indices,
+                        bs,
+                        self.device,
+                        self.token_to_kv_pool_allocator,
+                    )
+                )
+
             custom_mask = spec_info.custom_mask
             seq_mask_len = self.num_draft_tokens * (
                 forward_batch.seq_lens + self.num_draft_tokens
@@ -308,6 +323,7 @@ class TritonAttnBackend(AttentionBackend):
                         forward_batch.req_pool_indices,
                         bs,
                         self.device,
+                        self.token_to_kv_pool_allocator,
                     )
                 )
             qo_indptr = self.qo_indptr
@@ -423,14 +439,17 @@ class TritonAttnBackend(AttentionBackend):
             ):
                 window_kv_indices = self.cuda_graph_window_kv_indices
                 window_num_kv_splits = self.cuda_graph_window_num_kv_splits
-                window_kv_indptr, _ = update_sliding_window_buffer_cuda_graph(
-                    self.window_kv_indptr,
-                    window_kv_indices,
-                    self.req_to_token,
-                    self.sliding_window_size,
-                    seq_lens[:bs],
-                    req_pool_indices,
-                    bs,
+                window_kv_indptr, window_kv_indices, _ = (
+                    update_sliding_window_buffer_cuda_graph(
+                        self.window_kv_indptr,
+                        window_kv_indices,
+                        self.req_to_token,
+                        self.sliding_window_size,
+                        seq_lens[:bs],
+                        req_pool_indices,
+                        bs,
+                        self.token_to_kv_pool_allocator,
+                    )
                 )
             else:
                 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
@@ -464,6 +483,22 @@ class TritonAttnBackend(AttentionBackend):
                 self.req_to_token.stride(0),
             )

+            if self.sliding_window_size is not None and self.sliding_window_size > 0:
+                window_kv_indices = self.cuda_graph_window_kv_indices
+                window_num_kv_splits = self.cuda_graph_window_num_kv_splits
+                window_kv_indptr, window_kv_indices, _ = (
+                    update_sliding_window_buffer_cuda_graph(
+                        self.window_kv_indptr,
+                        window_kv_indices,
+                        self.req_to_token,
+                        self.sliding_window_size,
+                        seq_lens,
+                        req_pool_indices,
+                        bs,
+                        self.token_to_kv_pool_allocator,
+                    )
+                )
+
             custom_mask = self.cuda_graph_custom_mask
             custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
             seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
@@ -557,7 +592,7 @@ class TritonAttnBackend(AttentionBackend):
             ):
                 window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                 window_kv_indices = self.cuda_graph_window_kv_indices
-                _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
+                _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
                     self.window_kv_indptr,
                     window_kv_indices,
                     self.req_to_token,
@@ -565,6 +600,7 @@ class TritonAttnBackend(AttentionBackend):
                     seq_lens[:bs],
                     req_pool_indices[:bs],
                     bs,
+                    self.token_to_kv_pool_allocator,
                 )
                 self.get_num_kv_splits(
                     window_num_kv_splits[:num_token], window_kv_lens[:bs]
@@ -599,6 +635,19 @@ class TritonAttnBackend(AttentionBackend):
                 kv_indices,
                 self.req_to_token.stride(0),
             )
+            if self.sliding_window_size is not None and self.sliding_window_size > 0:
+                window_num_kv_splits = self.cuda_graph_window_num_kv_splits
+                window_kv_indices = self.cuda_graph_window_kv_indices
+                _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
+                    self.window_kv_indptr,
+                    window_kv_indices,
+                    self.req_to_token,
+                    self.sliding_window_size,
+                    seq_lens,
+                    req_pool_indices,
+                    bs,
+                    self.token_to_kv_pool_allocator,
+                )
             custom_mask = self.cuda_graph_custom_mask
             custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
             seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
@@ -637,6 +686,7 @@ class TritonAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        sinks=None,
     ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
@@ -680,7 +730,8 @@ class TritonAttnBackend(AttentionBackend):
             self.forward_metadata.max_extend_len,
             layer.scaling,
             layer.logit_cap,
-            sliding_window_size,
+            sliding_window_size=sliding_window_size,
+            sinks=sinks,
         )
         return o

@@ -692,6 +743,7 @@ class TritonAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        sinks=None,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -728,6 +780,7 @@ class TritonAttnBackend(AttentionBackend):
             self.max_kv_splits,
             layer.scaling,
             layer.logit_cap,
+            sinks=sinks,
         )
         return o

@@ -932,10 +985,11 @@ def update_sliding_window_buffer(
     req_pool_indices,
     bs,
     device,
+    token_to_kv_pool_allocator=None,
 ):
     window_kv_lens = torch.minimum(
         seq_lens,
-        torch.tensor(sliding_window_size + 1),
+        torch.tensor(sliding_window_size),
     )
     window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
     window_kv_indptr = window_kv_indptr[: bs + 1]
@@ -952,6 +1006,14 @@ def update_sliding_window_buffer(
         window_kv_indices,
         req_to_token.stride(0),
     )
+    # full to swa index mapping
+    if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
+        kv_last_index = window_kv_indptr[-1]
+        window_kv_indices[:kv_last_index] = (
+            token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+                window_kv_indices[:kv_last_index]
+            )
+        )
     return window_kv_indptr, window_kv_indices, window_kv_lens


@@ -963,10 +1025,11 @@ def update_sliding_window_buffer_cuda_graph(
     seq_lens,
     req_pool_indices,
     bs,
+    token_to_kv_pool_allocator=None,
 ):
     window_kv_lens = torch.minimum(
         seq_lens,
-        torch.tensor(sliding_window_size + 1),
+        torch.tensor(sliding_window_size),
     )
     window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
     window_kv_indptr = window_kv_indptr[: bs + 1]
@@ -980,4 +1043,12 @@ def update_sliding_window_buffer_cuda_graph(
         window_kv_indices,
         req_to_token.stride(0),
     )
-    return window_kv_indptr, window_kv_lens
+    # full to swa index mapping
+    if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
+        kv_last_index = window_kv_indptr[-1]
+        window_kv_indices[:kv_last_index] = (
+            token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+                window_kv_indices[:kv_last_index]
+            )
+        )
+    return window_kv_indptr, window_kv_indices, window_kv_lens
```
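Two threads run through the TritonAttnBackend hunks: every call site now hands the `token_to_kv_pool_allocator` to `update_sliding_window_buffer` and its CUDA-graph variant, and both helpers now return the (possibly remapped) `window_kv_indices` as an additional value. When the allocator exposes `translate_loc_from_full_to_swa`, the window's KV indices are translated from slots in the full-attention KV pool into a separate sliding-window (SWA) pool; the `hasattr` guard plus the `token_to_kv_pool_allocator=None` default keep allocators without an SWA pool on the old behavior, since `hasattr(None, ...)` is simply `False`. A toy sketch of that translation contract, with an allocator class invented purely for illustration:

```python
import torch

class ToyHybridAllocator:
    """Invented for illustration: slots in a full KV pool map onto a smaller SWA pool."""

    def __init__(self, full_size: int, swa_size: int):
        # A real allocator would maintain this table as tokens are allocated and freed;
        # here every full-pool slot is simply folded into the SWA pool by modulo.
        self.full_to_swa = torch.arange(full_size) % swa_size

    def translate_loc_from_full_to_swa(self, loc: torch.Tensor) -> torch.Tensor:
        return self.full_to_swa[loc]

allocator = ToyHybridAllocator(full_size=1024, swa_size=128)
window_kv_indices = torch.tensor([0, 130, 900])  # slots in the full pool
print(allocator.translate_loc_from_full_to_swa(window_kv_indices))  # tensor([0, 2, 4])
```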
sglang/srt/layers/attention/triton_ops/decode_attention.py

```diff
@@ -495,6 +495,7 @@ def _fwd_kernel_stage2(
     O,
     kv_indptr,
     num_kv_splits,
+    sink_ptr,
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
@@ -504,6 +505,7 @@ def _fwd_kernel_stage2(
     MIN_BLOCK_KV: tl.constexpr,
     BLOCK_DV: tl.constexpr,
     Lv: tl.constexpr,
+    HAS_SINK: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -545,6 +547,10 @@ def _fwd_kernel_stage2(
         e_sum = e_sum * old_scale + exp_logic
         e_max = n_e_max

+    if HAS_SINK:
+        cur_sink = tl.load(sink_ptr + cur_head)
+        e_sum += tl.exp(cur_sink - e_max)
+
     tl.store(
         O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
         acc / e_sum,
@@ -561,12 +567,14 @@ def _decode_softmax_reducev_fwd(
     kv_indptr,
     num_kv_splits,
     max_kv_splits,
+    sinks=None,
 ):
     batch, head_num = q.shape[0], q.shape[1]
     Lv = v_buffer.shape[-1]
     BLOCK_DV = triton.next_power_of_2(Lv)

     MAX_KV_SPLITS = max_kv_splits
+    HAS_SINK = sinks is not None

     extra_kargs = {}
     if _is_hip:
@@ -581,6 +589,7 @@ def _decode_softmax_reducev_fwd(
         o,
         kv_indptr,
         num_kv_splits,
+        sinks,
         logits.stride(0),
         logits.stride(1),
         logits.stride(2),
@@ -590,6 +599,7 @@ def _decode_softmax_reducev_fwd(
         MIN_BLOCK_KV=_MIN_BLOCK_KV,
         BLOCK_DV=BLOCK_DV,
         Lv=Lv,
+        HAS_SINK=HAS_SINK,
         num_warps=4,
         num_stages=2,
         **extra_kargs,
@@ -609,6 +619,7 @@ def decode_attention_fwd_normal(
     max_kv_splits,
     sm_scale,
     logit_cap=0.0,
+    sinks=None,
 ):
     _decode_att_m_fwd(
         q,
@@ -632,6 +643,7 @@ def decode_attention_fwd_normal(
         kv_indptr,
         num_kv_splits,
         max_kv_splits,
+        sinks,
     )


@@ -648,6 +660,7 @@ def decode_attention_fwd_grouped(
     max_kv_splits,
     sm_scale,
     logit_cap=0.0,
+    sinks=None,
 ):
     _decode_grouped_att_m_fwd(
         q,
@@ -671,6 +684,7 @@ def decode_attention_fwd_grouped(
         kv_indptr,
         num_kv_splits,
         max_kv_splits,
+        sinks,
     )


@@ -687,6 +701,7 @@ def decode_attention_fwd(
     max_kv_splits,
     sm_scale,
     logit_cap=0.0,
+    sinks=None,
 ):
     assert max_kv_splits == attn_logits.shape[2]
     assert q.shape[0] <= kv_indptr.shape[0] - 1
@@ -709,6 +724,7 @@ def decode_attention_fwd(
             max_kv_splits,
             sm_scale,
             logit_cap=logit_cap,
+            sinks=sinks,
         )
     else:
         # GQA/MQA/MLA
@@ -725,4 +741,5 @@ def decode_attention_fwd(
             max_kv_splits,
             sm_scale,
             logit_cap=logit_cap,
+            sinks=sinks,
         )
```
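These decode_attention.py hunks thread an optional `sinks` tensor from the public entry points (`decode_attention_fwd` and friends) down to the stage-2 split-KV reduction kernel; the plumbing arrives alongside the new `gpt_oss` model in this release. An attention sink is a learned per-head logit that joins the softmax normalization but contributes no value vector, so the only arithmetic the kernel needs is `e_sum += tl.exp(cur_sink - e_max)` after the online-softmax loop. Reference semantics in plain PyTorch (single head, names illustrative):

```python
import torch

def decode_attention_with_sink(q, k, v, sink):
    """q: [d], k/v: [n, d], sink: scalar logit for this head.

    The sink enlarges the softmax denominator without adding to the numerator,
    mirroring the HAS_SINK branch in _fwd_kernel_stage2 above.
    """
    logits = k @ q                        # [n] attention logits
    e_max = logits.max()                  # max from the online softmax
    weights = torch.exp(logits - e_max)   # unnormalized probabilities
    denom = weights.sum() + torch.exp(sink - e_max)  # the extra sink term
    return (weights @ v) / denom

q = torch.randn(64)
k, v = torch.randn(10, 64), torch.randn(10, 64)
out = decode_attention_with_sink(q, k, v, sink=torch.tensor(1.5))
```

Because the sink only inflates the denominator, it diverts some probability mass away from real tokens and softens the head's output. Making `HAS_SINK` a `tl.constexpr` lets Triton specialize the kernel at compile time, so the branch costs nothing when no sinks are passed.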