sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -7
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +25 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -2
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +29 -4
- sglang/srt/entrypoints/http_server.py +76 -0
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/entrypoints/openai/serving_chat.py +23 -6
- sglang/srt/entrypoints/openai/serving_completions.py +10 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +14 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
- sglang/srt/layers/attention/triton_backend.py +109 -73
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
- sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +58 -10
- sglang/srt/layers/dp_attention.py +137 -27
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +16 -18
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +18 -46
- sglang/srt/layers/quantization/awq.py +22 -23
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +17 -21
- sglang/srt/layers/quantization/marlin_utils.py +26 -8
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +217 -98
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +222 -39
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +77 -2
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/layers.py +6 -2
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +80 -19
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +23 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +22 -48
- sglang/srt/managers/scheduler.py +28 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +88 -39
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +10 -157
- sglang/srt/mem_cache/allocator_ascend.py +147 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +33 -33
- sglang/srt/model_executor/forward_batch_info.py +11 -10
- sglang/srt/model_executor/model_runner.py +93 -78
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +5 -2
- sglang/srt/models/deepseek_v2.py +226 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +27 -65
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +41 -76
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama.py +10 -2
- sglang/srt/models/llama4.py +18 -7
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +23 -23
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +84 -0
- sglang/srt/models/qwen3_moe.py +27 -43
- sglang/srt/models/step3_vl.py +8 -3
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +22 -2
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +264 -105
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +20 -19
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
- sglang/srt/layers/quantization/fp4.py +0 -557
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
@@ -20,6 +20,14 @@ if TYPE_CHECKING:
|
|
20
20
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
21
21
|
|
22
22
|
|
23
|
+
def logit_capping_mod(logit_capping_method, logit_cap):
|
24
|
+
# positive logit_cap -> tanh cap
|
25
|
+
if logit_capping_method == "tanh":
|
26
|
+
return logit_cap
|
27
|
+
else:
|
28
|
+
raise ValueError()
|
29
|
+
|
30
|
+
|
23
31
|
@dataclass
|
24
32
|
class ForwardMetadata:
|
25
33
|
attn_logits: torch.Tensor
|
@@ -35,6 +43,7 @@ class ForwardMetadata:
|
|
35
43
|
window_kv_indptr: torch.Tensor
|
36
44
|
window_kv_indices: torch.Tensor
|
37
45
|
window_num_kv_splits: torch.Tensor
|
46
|
+
window_kv_offsets: torch.Tensor
|
38
47
|
|
39
48
|
|
40
49
|
class TritonAttnBackend(AttentionBackend):
|
@@ -57,16 +66,36 @@ class TritonAttnBackend(AttentionBackend):
|
|
57
66
|
self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
|
58
67
|
self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
|
59
68
|
|
69
|
+
# Parse args
|
60
70
|
self.skip_prefill = skip_prefill
|
61
|
-
|
62
71
|
max_bs = model_runner.req_to_token_pool.size
|
72
|
+
self.sliding_window_size = model_runner.sliding_window_size
|
73
|
+
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
74
|
+
self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
|
75
|
+
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
|
76
|
+
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
|
77
|
+
self.num_head = (
|
78
|
+
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
79
|
+
)
|
80
|
+
self.num_kv_head = model_runner.model_config.get_num_kv_heads(
|
81
|
+
get_attention_tp_size()
|
82
|
+
)
|
83
|
+
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
|
84
|
+
self.max_context_len = model_runner.model_config.context_len
|
85
|
+
self.device = model_runner.device
|
86
|
+
self.device_core_count = get_device_core_count(model_runner.gpu_id)
|
87
|
+
self.static_kv_splits = get_bool_env_var(
|
88
|
+
"SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
|
89
|
+
)
|
90
|
+
self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
|
63
91
|
|
92
|
+
# Check arguments
|
64
93
|
assert not (
|
65
94
|
model_runner.sliding_window_size is not None
|
66
95
|
and model_runner.model_config.is_encoder_decoder
|
67
96
|
), "Sliding window and cross attention are not supported together"
|
68
|
-
self.sliding_window_size = model_runner.sliding_window_size
|
69
97
|
|
98
|
+
# Initialize buffers
|
70
99
|
# TODO(Jianan Ji): Make sure it behaves as expected when kv_indptr_buf is provided and sliding window is enabled
|
71
100
|
if kv_indptr_buf is None:
|
72
101
|
self.kv_indptr = torch.zeros(
|
@@ -87,9 +116,6 @@ class TritonAttnBackend(AttentionBackend):
|
|
87
116
|
# When provided a buffer, create a clone for the second buffer
|
88
117
|
self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)
|
89
118
|
|
90
|
-
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
91
|
-
self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
|
92
|
-
|
93
119
|
if not self.skip_prefill:
|
94
120
|
self.qo_indptr = torch.zeros(
|
95
121
|
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
@@ -99,29 +125,9 @@ class TritonAttnBackend(AttentionBackend):
|
|
99
125
|
(max_bs + 1,), dtype=torch.int64, device=model_runner.device
|
100
126
|
)
|
101
127
|
|
102
|
-
|
103
|
-
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
|
104
|
-
|
105
|
-
self.num_head = (
|
106
|
-
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
107
|
-
)
|
108
|
-
self.num_kv_head = model_runner.model_config.get_num_kv_heads(
|
109
|
-
get_attention_tp_size()
|
110
|
-
)
|
111
|
-
|
112
|
-
self.static_kv_splits = get_bool_env_var(
|
113
|
-
"SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
|
114
|
-
)
|
115
|
-
self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
|
116
|
-
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
|
117
|
-
|
128
|
+
# Initialize forward metadata
|
118
129
|
self.forward_metadata: ForwardMetadata = None
|
119
130
|
|
120
|
-
self.max_context_len = model_runner.model_config.context_len
|
121
|
-
|
122
|
-
self.device = model_runner.device
|
123
|
-
self.device_core_count = get_device_core_count(model_runner.gpu_id)
|
124
|
-
|
125
131
|
def get_num_kv_splits(
|
126
132
|
self,
|
127
133
|
num_kv_splits: torch.Tensor,
|
@@ -166,6 +172,7 @@ class TritonAttnBackend(AttentionBackend):
|
|
166
172
|
window_kv_indptr = self.window_kv_indptr
|
167
173
|
window_kv_indices = None
|
168
174
|
window_num_kv_splits = None
|
175
|
+
window_kv_offsets = None
|
169
176
|
spec_info = forward_batch.spec_info
|
170
177
|
|
171
178
|
if forward_batch.forward_mode.is_decode_or_idle():
|
@@ -173,7 +180,7 @@ class TritonAttnBackend(AttentionBackend):
|
|
173
180
|
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
|
174
181
|
kv_indptr = kv_indptr[: bs + 1]
|
175
182
|
kv_indices = torch.empty(
|
176
|
-
forward_batch.seq_lens_sum, dtype=torch.
|
183
|
+
forward_batch.seq_lens_sum, dtype=torch.int64, device=self.device
|
177
184
|
)
|
178
185
|
create_flashinfer_kv_indices_triton[(bs,)](
|
179
186
|
self.req_to_token,
|
@@ -189,7 +196,7 @@ class TritonAttnBackend(AttentionBackend):
|
|
189
196
|
self.sliding_window_size is not None
|
190
197
|
and self.sliding_window_size > 0
|
191
198
|
):
|
192
|
-
window_kv_indptr, window_kv_indices, window_kv_lens = (
|
199
|
+
window_kv_indptr, window_kv_indices, window_kv_lens, _ = (
|
193
200
|
update_sliding_window_buffer(
|
194
201
|
self.window_kv_indptr,
|
195
202
|
self.req_to_token,
|
@@ -239,7 +246,7 @@ class TritonAttnBackend(AttentionBackend):
|
|
239
246
|
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
|
240
247
|
kv_indptr = kv_indptr[: bs + 1]
|
241
248
|
kv_indices = torch.empty(
|
242
|
-
kv_indptr[-1], dtype=torch.
|
249
|
+
kv_indptr[-1], dtype=torch.int64, device=self.device
|
243
250
|
)
|
244
251
|
create_flashinfer_kv_indices_triton[(bs,)](
|
245
252
|
self.req_to_token,
|
@@ -252,17 +259,21 @@ class TritonAttnBackend(AttentionBackend):
|
|
252
259
|
)
|
253
260
|
|
254
261
|
if self.sliding_window_size is not None and self.sliding_window_size > 0:
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
262
|
+
# window_kv_offsets is used to calculate the start position in custom mask
|
263
|
+
(
|
264
|
+
window_kv_indptr,
|
265
|
+
window_kv_indices,
|
266
|
+
window_kv_lens,
|
267
|
+
window_kv_offsets,
|
268
|
+
) = update_sliding_window_buffer(
|
269
|
+
self.window_kv_indptr,
|
270
|
+
self.req_to_token,
|
271
|
+
self.sliding_window_size,
|
272
|
+
forward_batch.seq_lens,
|
273
|
+
forward_batch.req_pool_indices,
|
274
|
+
bs,
|
275
|
+
self.device,
|
276
|
+
self.token_to_kv_pool_allocator,
|
266
277
|
)
|
267
278
|
|
268
279
|
custom_mask = spec_info.custom_mask
|
@@ -286,6 +297,7 @@ class TritonAttnBackend(AttentionBackend):
|
|
286
297
|
self.req_to_token,
|
287
298
|
)
|
288
299
|
)
|
300
|
+
kv_indices = kv_indices.to(torch.int64)
|
289
301
|
mask_indptr = None
|
290
302
|
# TODO(FIXME): This will trigger an invalid Eagle tree when using
|
291
303
|
# `max(spec_info.accept_length_cpu)`.
|
@@ -301,7 +313,7 @@ class TritonAttnBackend(AttentionBackend):
|
|
301
313
|
kv_indptr = kv_indptr[: bs + 1]
|
302
314
|
kv_indices = torch.empty(
|
303
315
|
forward_batch.extend_prefix_lens.sum().item(),
|
304
|
-
dtype=torch.
|
316
|
+
dtype=torch.int64,
|
305
317
|
device=self.device,
|
306
318
|
)
|
307
319
|
create_flashinfer_kv_indices_triton[(bs,)](
|
@@ -315,15 +327,17 @@ class TritonAttnBackend(AttentionBackend):
|
|
315
327
|
)
|
316
328
|
# Sliding window
|
317
329
|
if self.sliding_window_size is not None and self.sliding_window_size > 0:
|
318
|
-
window_kv_indptr, window_kv_indices, _ =
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
330
|
+
window_kv_indptr, window_kv_indices, _, _ = (
|
331
|
+
update_sliding_window_buffer(
|
332
|
+
self.window_kv_indptr,
|
333
|
+
self.req_to_token,
|
334
|
+
self.sliding_window_size,
|
335
|
+
forward_batch.extend_prefix_lens,
|
336
|
+
forward_batch.req_pool_indices,
|
337
|
+
bs,
|
338
|
+
self.device,
|
339
|
+
self.token_to_kv_pool_allocator,
|
340
|
+
)
|
327
341
|
)
|
328
342
|
|
329
343
|
qo_indptr = self.qo_indptr
|
@@ -333,7 +347,7 @@ class TritonAttnBackend(AttentionBackend):
|
|
333
347
|
mask_indptr = None
|
334
348
|
attn_logits = None
|
335
349
|
attn_lse = None
|
336
|
-
max_extend_len =
|
350
|
+
max_extend_len = max(forward_batch.extend_seq_lens_cpu)
|
337
351
|
num_kv_splits = None
|
338
352
|
|
339
353
|
self.forward_metadata = ForwardMetadata(
|
@@ -349,6 +363,7 @@ class TritonAttnBackend(AttentionBackend):
|
|
349
363
|
window_kv_indptr,
|
350
364
|
window_kv_indices,
|
351
365
|
window_num_kv_splits,
|
366
|
+
window_kv_offsets,
|
352
367
|
)
|
353
368
|
|
354
369
|
def init_cuda_graph_state(
|
@@ -373,7 +388,7 @@ class TritonAttnBackend(AttentionBackend):
|
|
373
388
|
if kv_indices_buf is None:
|
374
389
|
self.cuda_graph_kv_indices = torch.zeros(
|
375
390
|
(max_num_tokens * self.max_context_len),
|
376
|
-
dtype=torch.
|
391
|
+
dtype=torch.int64,
|
377
392
|
device=self.device,
|
378
393
|
)
|
379
394
|
else:
|
@@ -390,7 +405,7 @@ class TritonAttnBackend(AttentionBackend):
|
|
390
405
|
if kv_indices_buf is None:
|
391
406
|
self.cuda_graph_window_kv_indices = torch.zeros(
|
392
407
|
(max_num_tokens * self.sliding_window_size),
|
393
|
-
dtype=torch.
|
408
|
+
dtype=torch.int64,
|
394
409
|
device=self.device,
|
395
410
|
)
|
396
411
|
else:
|
@@ -403,6 +418,12 @@ class TritonAttnBackend(AttentionBackend):
|
|
403
418
|
device=self.device,
|
404
419
|
)
|
405
420
|
|
421
|
+
self.cuda_graph_window_kv_offsets = torch.zeros(
|
422
|
+
(max_bs,),
|
423
|
+
dtype=torch.int32,
|
424
|
+
device=self.device,
|
425
|
+
)
|
426
|
+
|
406
427
|
def init_forward_metadata_capture_cuda_graph(
|
407
428
|
self,
|
408
429
|
bs: int,
|
@@ -417,6 +438,7 @@ class TritonAttnBackend(AttentionBackend):
|
|
417
438
|
window_kv_indptr = self.window_kv_indptr
|
418
439
|
window_kv_indices = None
|
419
440
|
window_num_kv_splits = None
|
441
|
+
window_kv_offsets = None
|
420
442
|
|
421
443
|
if forward_mode.is_decode_or_idle():
|
422
444
|
if spec_info is None:
|
@@ -439,7 +461,7 @@ class TritonAttnBackend(AttentionBackend):
|
|
439
461
|
):
|
440
462
|
window_kv_indices = self.cuda_graph_window_kv_indices
|
441
463
|
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
|
442
|
-
window_kv_indptr, window_kv_indices, _ = (
|
464
|
+
window_kv_indptr, window_kv_indices, _, _ = (
|
443
465
|
update_sliding_window_buffer_cuda_graph(
|
444
466
|
self.window_kv_indptr,
|
445
467
|
window_kv_indices,
|
@@ -486,13 +508,14 @@ class TritonAttnBackend(AttentionBackend):
|
|
486
508
|
if self.sliding_window_size is not None and self.sliding_window_size > 0:
|
487
509
|
window_kv_indices = self.cuda_graph_window_kv_indices
|
488
510
|
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
|
489
|
-
|
511
|
+
window_kv_offsets = self.cuda_graph_window_kv_offsets
|
512
|
+
window_kv_indptr, window_kv_indices, _, window_kv_offsets[:bs] = (
|
490
513
|
update_sliding_window_buffer_cuda_graph(
|
491
514
|
self.window_kv_indptr,
|
492
515
|
window_kv_indices,
|
493
516
|
self.req_to_token,
|
494
517
|
self.sliding_window_size,
|
495
|
-
seq_lens,
|
518
|
+
seq_lens[:bs],
|
496
519
|
req_pool_indices,
|
497
520
|
bs,
|
498
521
|
self.token_to_kv_pool_allocator,
|
@@ -554,6 +577,7 @@ class TritonAttnBackend(AttentionBackend):
|
|
554
577
|
window_kv_indptr,
|
555
578
|
window_kv_indices,
|
556
579
|
window_num_kv_splits,
|
580
|
+
window_kv_offsets,
|
557
581
|
)
|
558
582
|
|
559
583
|
def init_forward_metadata_replay_cuda_graph(
|
@@ -592,7 +616,7 @@ class TritonAttnBackend(AttentionBackend):
|
|
592
616
|
):
|
593
617
|
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
|
594
618
|
window_kv_indices = self.cuda_graph_window_kv_indices
|
595
|
-
_, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
|
619
|
+
_, _, window_kv_lens, _ = update_sliding_window_buffer_cuda_graph(
|
596
620
|
self.window_kv_indptr,
|
597
621
|
window_kv_indices,
|
598
622
|
self.req_to_token,
|
@@ -638,15 +662,18 @@ class TritonAttnBackend(AttentionBackend):
|
|
638
662
|
if self.sliding_window_size is not None and self.sliding_window_size > 0:
|
639
663
|
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
|
640
664
|
window_kv_indices = self.cuda_graph_window_kv_indices
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
649
|
-
|
665
|
+
window_kv_offsets = self.cuda_graph_window_kv_offsets
|
666
|
+
_, _, window_kv_lens, window_kv_offsets[:bs] = (
|
667
|
+
update_sliding_window_buffer_cuda_graph(
|
668
|
+
self.window_kv_indptr,
|
669
|
+
window_kv_indices,
|
670
|
+
self.req_to_token,
|
671
|
+
self.sliding_window_size,
|
672
|
+
seq_lens[:bs],
|
673
|
+
req_pool_indices,
|
674
|
+
bs,
|
675
|
+
self.token_to_kv_pool_allocator,
|
676
|
+
)
|
650
677
|
)
|
651
678
|
custom_mask = self.cuda_graph_custom_mask
|
652
679
|
custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
|
@@ -699,6 +726,8 @@ class TritonAttnBackend(AttentionBackend):
|
|
699
726
|
layer, forward_batch.out_cache_loc, k, v
|
700
727
|
)
|
701
728
|
|
729
|
+
logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
|
730
|
+
|
702
731
|
causal = True
|
703
732
|
if layer.attn_type == AttentionType.ENCODER_ONLY:
|
704
733
|
causal = False
|
@@ -709,10 +738,12 @@ class TritonAttnBackend(AttentionBackend):
|
|
709
738
|
) # Needed for sliding window mask
|
710
739
|
kv_indptr = self.forward_metadata.window_kv_indptr
|
711
740
|
kv_indices = self.forward_metadata.window_kv_indices
|
741
|
+
window_kv_offsets = self.forward_metadata.window_kv_offsets
|
712
742
|
else:
|
713
743
|
sliding_window_size = -1
|
714
744
|
kv_indptr = self.forward_metadata.kv_indptr
|
715
745
|
kv_indices = self.forward_metadata.kv_indices
|
746
|
+
window_kv_offsets = None
|
716
747
|
|
717
748
|
self.extend_attention_fwd(
|
718
749
|
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
@@ -729,9 +760,11 @@ class TritonAttnBackend(AttentionBackend):
|
|
729
760
|
self.forward_metadata.mask_indptr,
|
730
761
|
self.forward_metadata.max_extend_len,
|
731
762
|
layer.scaling,
|
732
|
-
|
763
|
+
logit_cap=logits_soft_cap,
|
733
764
|
sliding_window_size=sliding_window_size,
|
734
765
|
sinks=sinks,
|
766
|
+
window_kv_offsets=window_kv_offsets,
|
767
|
+
xai_temperature_len=layer.xai_temperature_len,
|
735
768
|
)
|
736
769
|
return o
|
737
770
|
|
@@ -755,6 +788,8 @@ class TritonAttnBackend(AttentionBackend):
|
|
755
788
|
else:
|
756
789
|
o = torch.empty_like(q)
|
757
790
|
|
791
|
+
logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
|
792
|
+
|
758
793
|
if save_kv_cache:
|
759
794
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
760
795
|
layer, forward_batch.out_cache_loc, k, v
|
@@ -779,8 +814,9 @@ class TritonAttnBackend(AttentionBackend):
|
|
779
814
|
self.forward_metadata.num_kv_splits,
|
780
815
|
self.max_kv_splits,
|
781
816
|
layer.scaling,
|
782
|
-
|
817
|
+
logit_cap=logits_soft_cap,
|
783
818
|
sinks=sinks,
|
819
|
+
xai_temperature_len=layer.xai_temperature_len,
|
784
820
|
)
|
785
821
|
return o
|
786
822
|
|
@@ -867,7 +903,7 @@ class TritonMultiStepDraftBackend:
|
|
867
903
|
self.speculative_num_steps,
|
868
904
|
forward_batch.batch_size * self.topk * self.max_context_len,
|
869
905
|
),
|
870
|
-
dtype=torch.
|
906
|
+
dtype=torch.int64,
|
871
907
|
device=self.device,
|
872
908
|
)
|
873
909
|
|
@@ -885,7 +921,7 @@ class TritonMultiStepDraftBackend:
|
|
885
921
|
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
886
922
|
self.cuda_graph_kv_indices = torch.zeros(
|
887
923
|
(self.speculative_num_steps, max_num_tokens * self.max_context_len),
|
888
|
-
dtype=torch.
|
924
|
+
dtype=torch.int64,
|
889
925
|
device=self.device,
|
890
926
|
)
|
891
927
|
for i in range(self.speculative_num_steps):
|
@@ -994,7 +1030,7 @@ def update_sliding_window_buffer(
|
|
994
1030
|
window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
|
995
1031
|
window_kv_indptr = window_kv_indptr[: bs + 1]
|
996
1032
|
window_kv_indices = torch.empty(
|
997
|
-
window_kv_indptr[-1], dtype=torch.
|
1033
|
+
window_kv_indptr[-1], dtype=torch.int64, device=device
|
998
1034
|
)
|
999
1035
|
window_kv_start_idx = seq_lens - window_kv_lens
|
1000
1036
|
create_flashinfer_kv_indices_triton[(bs,)](
|
@@ -1014,7 +1050,7 @@ def update_sliding_window_buffer(
|
|
1014
1050
|
window_kv_indices[:kv_last_index]
|
1015
1051
|
)
|
1016
1052
|
)
|
1017
|
-
return window_kv_indptr, window_kv_indices, window_kv_lens
|
1053
|
+
return window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_start_idx
|
1018
1054
|
|
1019
1055
|
|
1020
1056
|
def update_sliding_window_buffer_cuda_graph(
|
@@ -1051,4 +1087,4 @@ def update_sliding_window_buffer_cuda_graph(
|
|
1051
1087
|
window_kv_indices[:kv_last_index]
|
1052
1088
|
)
|
1053
1089
|
)
|
1054
|
-
return window_kv_indptr, window_kv_indices, window_kv_lens
|
1090
|
+
return window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_start_idx
|
@@ -69,6 +69,7 @@ def _fwd_kernel_stage1(
|
|
69
69
|
logit_cap: tl.constexpr,
|
70
70
|
Lk: tl.constexpr,
|
71
71
|
Lv: tl.constexpr,
|
72
|
+
xai_temperature_len: tl.constexpr,
|
72
73
|
):
|
73
74
|
cur_batch = tl.program_id(0)
|
74
75
|
cur_head = tl.program_id(1)
|
@@ -85,6 +86,12 @@ def _fwd_kernel_stage1(
|
|
85
86
|
cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
|
86
87
|
kv_splits = tl.load(num_kv_splits + cur_batch)
|
87
88
|
|
89
|
+
if xai_temperature_len > 0:
|
90
|
+
offs_qidx = cur_batch_seq_len - 1
|
91
|
+
xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
|
92
|
+
_qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale
|
93
|
+
xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0)
|
94
|
+
|
88
95
|
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
|
89
96
|
|
90
97
|
kv_len_per_split = (
|
@@ -122,6 +129,9 @@ def _fwd_kernel_stage1(
|
|
122
129
|
if logit_cap > 0:
|
123
130
|
qk = logit_cap * tanh(qk / logit_cap)
|
124
131
|
|
132
|
+
if xai_temperature_len > 0:
|
133
|
+
qk *= xai_temperature_reg
|
134
|
+
|
125
135
|
qk = tl.where(offs_n < split_kv_end, qk, float("-inf"))
|
126
136
|
|
127
137
|
offs_buf_v = (
|
@@ -181,6 +191,7 @@ def _decode_att_m_fwd(
|
|
181
191
|
max_kv_splits,
|
182
192
|
sm_scale,
|
183
193
|
logit_cap,
|
194
|
+
xai_temperature_len=-1,
|
184
195
|
):
|
185
196
|
BLOCK = 64
|
186
197
|
# [TODO] work around SGPR limit on MI3xx
|
@@ -190,7 +201,7 @@ def _decode_att_m_fwd(
|
|
190
201
|
Lk = k_buffer.shape[-1]
|
191
202
|
Lv = v_buffer.shape[-1]
|
192
203
|
|
193
|
-
batch, head_num =
|
204
|
+
batch, head_num = q.shape[0], q.shape[1]
|
194
205
|
|
195
206
|
grid = (batch, head_num, MAX_KV_SPLITS)
|
196
207
|
kv_group_num = q.shape[1] // k_buffer.shape[1]
|
@@ -230,6 +241,7 @@ def _decode_att_m_fwd(
|
|
230
241
|
BLOCK_N=BLOCK,
|
231
242
|
MIN_BLOCK_KV=_MIN_BLOCK_KV,
|
232
243
|
logit_cap=logit_cap,
|
244
|
+
xai_temperature_len=xai_temperature_len,
|
233
245
|
num_warps=num_warps,
|
234
246
|
num_stages=2,
|
235
247
|
Lk=Lk,
|
@@ -266,6 +278,7 @@ def _fwd_grouped_kernel_stage1(
|
|
266
278
|
BLOCK_H: tl.constexpr,
|
267
279
|
MIN_BLOCK_KV: tl.constexpr,
|
268
280
|
logit_cap: tl.constexpr,
|
281
|
+
xai_temperature_len: tl.constexpr,
|
269
282
|
Lk: tl.constexpr,
|
270
283
|
Lv: tl.constexpr,
|
271
284
|
):
|
@@ -291,6 +304,12 @@ def _fwd_grouped_kernel_stage1(
|
|
291
304
|
cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
|
292
305
|
kv_splits = tl.load(num_kv_splits + cur_batch)
|
293
306
|
|
307
|
+
if xai_temperature_len > 0:
|
308
|
+
offs_qidx = cur_batch_seq_len - 1
|
309
|
+
xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
|
310
|
+
_qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale
|
311
|
+
xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0)
|
312
|
+
|
294
313
|
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
|
295
314
|
|
296
315
|
if BLOCK_DPE > 0:
|
@@ -351,6 +370,9 @@ def _fwd_grouped_kernel_stage1(
|
|
351
370
|
if logit_cap > 0:
|
352
371
|
qk = logit_cap * tanh(qk / logit_cap)
|
353
372
|
|
373
|
+
if xai_temperature_len > 0:
|
374
|
+
qk *= xai_temperature_reg[:, None]
|
375
|
+
|
354
376
|
qk = tl.where(
|
355
377
|
mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
|
356
378
|
)
|
@@ -413,6 +435,7 @@ def _decode_grouped_att_m_fwd(
|
|
413
435
|
max_kv_splits,
|
414
436
|
sm_scale,
|
415
437
|
logit_cap,
|
438
|
+
xai_temperature_len=-1,
|
416
439
|
):
|
417
440
|
BLOCK = 32
|
418
441
|
Lk = k_buffer.shape[-1]
|
@@ -433,7 +456,7 @@ def _decode_grouped_att_m_fwd(
|
|
433
456
|
BLOCK_DPE = 0
|
434
457
|
BLOCK_DV = triton.next_power_of_2(Lv)
|
435
458
|
|
436
|
-
batch, head_num =
|
459
|
+
batch, head_num = q.shape[0], q.shape[1]
|
437
460
|
kv_group_num = q.shape[1] // k_buffer.shape[1]
|
438
461
|
|
439
462
|
BLOCK_H = 16
|
@@ -480,6 +503,7 @@ def _decode_grouped_att_m_fwd(
|
|
480
503
|
BLOCK_H=BLOCK_H,
|
481
504
|
MIN_BLOCK_KV=_MIN_BLOCK_KV,
|
482
505
|
logit_cap=logit_cap,
|
506
|
+
xai_temperature_len=xai_temperature_len,
|
483
507
|
num_warps=4,
|
484
508
|
num_stages=num_stages,
|
485
509
|
Lk=Lk,
|
@@ -620,6 +644,7 @@ def decode_attention_fwd_normal(
|
|
620
644
|
sm_scale,
|
621
645
|
logit_cap=0.0,
|
622
646
|
sinks=None,
|
647
|
+
xai_temperature_len=-1,
|
623
648
|
):
|
624
649
|
_decode_att_m_fwd(
|
625
650
|
q,
|
@@ -633,6 +658,7 @@ def decode_attention_fwd_normal(
|
|
633
658
|
max_kv_splits,
|
634
659
|
sm_scale,
|
635
660
|
logit_cap,
|
661
|
+
xai_temperature_len,
|
636
662
|
)
|
637
663
|
_decode_softmax_reducev_fwd(
|
638
664
|
attn_logits,
|
@@ -661,6 +687,7 @@ def decode_attention_fwd_grouped(
|
|
661
687
|
sm_scale,
|
662
688
|
logit_cap=0.0,
|
663
689
|
sinks=None,
|
690
|
+
xai_temperature_len=-1,
|
664
691
|
):
|
665
692
|
_decode_grouped_att_m_fwd(
|
666
693
|
q,
|
@@ -674,6 +701,7 @@ def decode_attention_fwd_grouped(
|
|
674
701
|
max_kv_splits,
|
675
702
|
sm_scale,
|
676
703
|
logit_cap,
|
704
|
+
xai_temperature_len,
|
677
705
|
)
|
678
706
|
_decode_softmax_reducev_fwd(
|
679
707
|
attn_logits,
|
@@ -702,6 +730,7 @@ def decode_attention_fwd(
|
|
702
730
|
sm_scale,
|
703
731
|
logit_cap=0.0,
|
704
732
|
sinks=None,
|
733
|
+
xai_temperature_len=-1,
|
705
734
|
):
|
706
735
|
assert max_kv_splits == attn_logits.shape[2]
|
707
736
|
assert q.shape[0] <= kv_indptr.shape[0] - 1
|
@@ -725,6 +754,7 @@ def decode_attention_fwd(
|
|
725
754
|
sm_scale,
|
726
755
|
logit_cap=logit_cap,
|
727
756
|
sinks=sinks,
|
757
|
+
xai_temperature_len=xai_temperature_len,
|
728
758
|
)
|
729
759
|
else:
|
730
760
|
# GQA/MQA/MLA
|
@@ -742,4 +772,5 @@ def decode_attention_fwd(
|
|
742
772
|
sm_scale,
|
743
773
|
logit_cap=logit_cap,
|
744
774
|
sinks=sinks,
|
775
|
+
xai_temperature_len=xai_temperature_len,
|
745
776
|
)
|
@@ -52,6 +52,7 @@ def _fwd_kernel(
|
|
52
52
|
mask_ptr,
|
53
53
|
mask_indptr,
|
54
54
|
sink_ptr,
|
55
|
+
window_kv_offset_ptr,
|
55
56
|
sm_scale,
|
56
57
|
kv_group_num,
|
57
58
|
stride_qbs,
|
@@ -68,6 +69,7 @@ def _fwd_kernel(
|
|
68
69
|
stride_buf_vh,
|
69
70
|
SLIDING_WINDOW_SIZE: tl.constexpr,
|
70
71
|
logit_cap: tl.constexpr,
|
72
|
+
xai_temperature_len: tl.constexpr,
|
71
73
|
Lq: tl.constexpr,
|
72
74
|
Lv: tl.constexpr,
|
73
75
|
BLOCK_DMODEL: tl.constexpr,
|
@@ -95,6 +97,11 @@ def _fwd_kernel(
|
|
95
97
|
if USE_CUSTOM_MASK:
|
96
98
|
cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)
|
97
99
|
|
100
|
+
# For SWA, we should only load the mask in the sliding window
|
101
|
+
window_kv_offset = 0
|
102
|
+
if USE_CUSTOM_MASK and SLIDING_WINDOW_SIZE > 0:
|
103
|
+
window_kv_offset = tl.load(window_kv_offset_ptr + cur_seq)
|
104
|
+
|
98
105
|
offs_d = tl.arange(0, BLOCK_DMODEL)
|
99
106
|
offs_dv = tl.arange(0, BLOCK_DV)
|
100
107
|
offs_m = tl.arange(0, BLOCK_M)
|
@@ -103,6 +110,15 @@ def _fwd_kernel(
|
|
103
110
|
mask_d = offs_d < Lq
|
104
111
|
mask_dv = offs_dv < Lv
|
105
112
|
|
113
|
+
if xai_temperature_len > 0:
|
114
|
+
offs_qidx = cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m
|
115
|
+
xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
|
116
|
+
xai_temperature_reg = tl.where(
|
117
|
+
offs_qidx > xai_temperature_len,
|
118
|
+
tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale,
|
119
|
+
1.0,
|
120
|
+
)
|
121
|
+
|
106
122
|
offs_q = (
|
107
123
|
(cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
|
108
124
|
* stride_qbs
|
@@ -139,7 +155,9 @@ def _fwd_kernel(
|
|
139
155
|
custom_mask = tl.load(
|
140
156
|
mask_ptr
|
141
157
|
+ cur_seq_mask_start_idx
|
142
|
-
+ (cur_block_m * BLOCK_M + offs_m[:, None])
|
158
|
+
+ (cur_block_m * BLOCK_M + offs_m[:, None])
|
159
|
+
* (cur_seq_len + window_kv_offset)
|
160
|
+
+ window_kv_offset
|
143
161
|
+ start_n
|
144
162
|
+ offs_n[None, :],
|
145
163
|
mask=(mask_m[:, None] & mask_n[None, :]),
|
@@ -195,6 +213,9 @@ def _fwd_kernel(
|
|
195
213
|
if logit_cap > 0:
|
196
214
|
qk = logit_cap * tanh(qk / logit_cap)
|
197
215
|
|
216
|
+
if xai_temperature_len > 0:
|
217
|
+
qk *= xai_temperature_reg[:, None]
|
218
|
+
|
198
219
|
qk = tl.where(final_mask, qk, float("-inf"))
|
199
220
|
|
200
221
|
row_max = tl.max(qk, 1)
|
@@ -236,7 +257,9 @@ def _fwd_kernel(
|
|
236
257
|
custom_mask = tl.load(
|
237
258
|
mask_ptr
|
238
259
|
+ cur_seq_mask_start_idx
|
239
|
-
+ (cur_block_m * BLOCK_M + offs_m[:, None])
|
260
|
+
+ (cur_block_m * BLOCK_M + offs_m[:, None])
|
261
|
+
* (cur_seq_len + window_kv_offset)
|
262
|
+
+ window_kv_offset
|
240
263
|
+ cur_seq_len_prefix
|
241
264
|
+ start_n
|
242
265
|
+ offs_n[None, :],
|
@@ -296,6 +319,9 @@ def _fwd_kernel(
|
|
296
319
|
if logit_cap > 0:
|
297
320
|
qk = logit_cap * tanh(qk / logit_cap)
|
298
321
|
|
322
|
+
if xai_temperature_len > 0:
|
323
|
+
qk *= xai_temperature_reg[:, None]
|
324
|
+
|
299
325
|
qk = tl.where(final_mask, qk, float("-inf"))
|
300
326
|
|
301
327
|
row_max = tl.max(qk, 1)
|
@@ -362,6 +388,8 @@ def extend_attention_fwd(
|
|
362
388
|
skip_prefix_custom_mask=True,
|
363
389
|
sliding_window_size=-1,
|
364
390
|
sinks=None,
|
391
|
+
window_kv_offsets=None,
|
392
|
+
xai_temperature_len=-1,
|
365
393
|
):
|
366
394
|
"""
|
367
395
|
q_extend, k_extend, v_extend, o_extend: contiguous tensors
|
@@ -449,6 +477,7 @@ def extend_attention_fwd(
|
|
449
477
|
custom_mask,
|
450
478
|
mask_indptr,
|
451
479
|
sinks,
|
480
|
+
window_kv_offsets,
|
452
481
|
sm_scale,
|
453
482
|
kv_group_num,
|
454
483
|
q_extend.stride(0),
|
@@ -465,6 +494,7 @@ def extend_attention_fwd(
|
|
465
494
|
v_buffer.stride(1),
|
466
495
|
SLIDING_WINDOW_SIZE=sliding_window_size,
|
467
496
|
logit_cap=logit_cap,
|
497
|
+
xai_temperature_len=xai_temperature_len,
|
468
498
|
BLOCK_DMODEL=BLOCK_DMODEL,
|
469
499
|
BLOCK_DPE=BLOCK_DPE,
|
470
500
|
BLOCK_DV=BLOCK_DV,
|