sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203) hide show
  1. sglang/bench_one_batch.py +0 -7
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +25 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -2
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +29 -4
  24. sglang/srt/entrypoints/http_server.py +76 -0
  25. sglang/srt/entrypoints/openai/protocol.py +4 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +23 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +10 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +14 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
  37. sglang/srt/layers/attention/triton_backend.py +109 -73
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +58 -10
  46. sglang/srt/layers/dp_attention.py +137 -27
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +16 -18
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  71. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  72. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  73. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  75. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  76. sglang/srt/layers/moe/router.py +15 -9
  77. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  78. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  79. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  80. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  81. sglang/srt/layers/moe/topk.py +167 -83
  82. sglang/srt/layers/moe/utils.py +159 -18
  83. sglang/srt/layers/multimodal.py +156 -40
  84. sglang/srt/layers/quantization/__init__.py +18 -46
  85. sglang/srt/layers/quantization/awq.py +22 -23
  86. sglang/srt/layers/quantization/base_config.py +2 -6
  87. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  88. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
  89. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  90. sglang/srt/layers/quantization/fp8.py +127 -119
  91. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  92. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  93. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  94. sglang/srt/layers/quantization/gptq.py +17 -21
  95. sglang/srt/layers/quantization/marlin_utils.py +26 -8
  96. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  97. sglang/srt/layers/quantization/modelopt_quant.py +217 -98
  98. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  99. sglang/srt/layers/quantization/mxfp4.py +222 -39
  100. sglang/srt/layers/quantization/quark/quark.py +390 -0
  101. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  102. sglang/srt/layers/quantization/unquant.py +34 -70
  103. sglang/srt/layers/quantization/utils.py +77 -2
  104. sglang/srt/layers/quantization/w4afp8.py +7 -8
  105. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  106. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  107. sglang/srt/layers/radix_attention.py +6 -0
  108. sglang/srt/layers/rotary_embedding.py +1 -0
  109. sglang/srt/layers/sampler.py +5 -2
  110. sglang/srt/lora/layers.py +6 -2
  111. sglang/srt/lora/lora_manager.py +21 -22
  112. sglang/srt/lora/lora_registry.py +3 -3
  113. sglang/srt/lora/mem_pool.py +26 -24
  114. sglang/srt/lora/utils.py +10 -12
  115. sglang/srt/managers/cache_controller.py +80 -19
  116. sglang/srt/managers/detokenizer_manager.py +10 -2
  117. sglang/srt/managers/io_struct.py +23 -0
  118. sglang/srt/managers/mm_utils.py +1 -1
  119. sglang/srt/managers/schedule_batch.py +22 -48
  120. sglang/srt/managers/scheduler.py +28 -20
  121. sglang/srt/managers/session_controller.py +1 -1
  122. sglang/srt/managers/template_manager.py +7 -5
  123. sglang/srt/managers/tokenizer_manager.py +88 -39
  124. sglang/srt/managers/tp_worker.py +1 -0
  125. sglang/srt/managers/utils.py +59 -1
  126. sglang/srt/mem_cache/allocator.py +10 -157
  127. sglang/srt/mem_cache/allocator_ascend.py +147 -0
  128. sglang/srt/mem_cache/chunk_cache.py +1 -1
  129. sglang/srt/mem_cache/hicache_storage.py +14 -4
  130. sglang/srt/mem_cache/memory_pool.py +3 -3
  131. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  132. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  133. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  134. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  135. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  136. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  137. sglang/srt/model_executor/cuda_graph_runner.py +33 -33
  138. sglang/srt/model_executor/forward_batch_info.py +11 -10
  139. sglang/srt/model_executor/model_runner.py +93 -78
  140. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  141. sglang/srt/model_loader/loader.py +24 -6
  142. sglang/srt/models/dbrx.py +12 -6
  143. sglang/srt/models/deepseek.py +2 -1
  144. sglang/srt/models/deepseek_nextn.py +5 -2
  145. sglang/srt/models/deepseek_v2.py +226 -223
  146. sglang/srt/models/ernie4.py +2 -2
  147. sglang/srt/models/glm4_moe.py +27 -65
  148. sglang/srt/models/glm4_moe_nextn.py +2 -1
  149. sglang/srt/models/glm4v.py +52 -1
  150. sglang/srt/models/glm4v_moe.py +8 -11
  151. sglang/srt/models/gpt_oss.py +41 -76
  152. sglang/srt/models/granitemoe.py +0 -1
  153. sglang/srt/models/grok.py +376 -48
  154. sglang/srt/models/interns1.py +12 -47
  155. sglang/srt/models/internvl.py +6 -51
  156. sglang/srt/models/llama.py +10 -2
  157. sglang/srt/models/llama4.py +18 -7
  158. sglang/srt/models/minicpm3.py +0 -1
  159. sglang/srt/models/mixtral.py +0 -2
  160. sglang/srt/models/nemotron_nas.py +435 -0
  161. sglang/srt/models/olmoe.py +0 -1
  162. sglang/srt/models/phi4mm.py +3 -21
  163. sglang/srt/models/qwen2.py +2 -2
  164. sglang/srt/models/qwen2_5_vl.py +2 -0
  165. sglang/srt/models/qwen2_moe.py +23 -23
  166. sglang/srt/models/qwen3.py +2 -2
  167. sglang/srt/models/qwen3_classification.py +84 -0
  168. sglang/srt/models/qwen3_moe.py +27 -43
  169. sglang/srt/models/step3_vl.py +8 -3
  170. sglang/srt/models/xverse_moe.py +11 -5
  171. sglang/srt/multimodal/processors/base_processor.py +3 -3
  172. sglang/srt/multimodal/processors/internvl.py +7 -2
  173. sglang/srt/multimodal/processors/llava.py +11 -7
  174. sglang/srt/offloader.py +433 -0
  175. sglang/srt/operations.py +22 -2
  176. sglang/srt/reasoning_parser.py +4 -3
  177. sglang/srt/sampling/sampling_batch_info.py +7 -4
  178. sglang/srt/server_args.py +264 -105
  179. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
  180. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  181. sglang/srt/speculative/eagle_utils.py +36 -13
  182. sglang/srt/speculative/eagle_worker.py +56 -3
  183. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  184. sglang/srt/two_batch_overlap.py +20 -19
  185. sglang/srt/utils.py +68 -70
  186. sglang/test/runners.py +8 -5
  187. sglang/test/test_block_fp8.py +5 -6
  188. sglang/test/test_block_fp8_ep.py +13 -19
  189. sglang/test/test_cutlass_moe.py +4 -6
  190. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  191. sglang/test/test_fp4_moe.py +4 -3
  192. sglang/test/test_marlin_moe.py +1 -1
  193. sglang/test/test_marlin_utils.py +1 -1
  194. sglang/test/test_utils.py +7 -0
  195. sglang/utils.py +0 -1
  196. sglang/version.py +1 -1
  197. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
  198. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
  199. sglang/srt/layers/quantization/fp4.py +0 -557
  200. sglang/srt/layers/quantization/scalar_type.py +0 -352
  201. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  202. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  203. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
@@ -20,6 +20,14 @@ if TYPE_CHECKING:
20
20
  from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
21
21
 
22
22
 
23
+ def logit_capping_mod(logit_capping_method, logit_cap):
24
+ # positive logit_cap -> tanh cap
25
+ if logit_capping_method == "tanh":
26
+ return logit_cap
27
+ else:
28
+ raise ValueError()
29
+
30
+
23
31
  @dataclass
24
32
  class ForwardMetadata:
25
33
  attn_logits: torch.Tensor
@@ -35,6 +43,7 @@ class ForwardMetadata:
35
43
  window_kv_indptr: torch.Tensor
36
44
  window_kv_indices: torch.Tensor
37
45
  window_num_kv_splits: torch.Tensor
46
+ window_kv_offsets: torch.Tensor
38
47
 
39
48
 
40
49
  class TritonAttnBackend(AttentionBackend):
@@ -57,16 +66,36 @@ class TritonAttnBackend(AttentionBackend):
57
66
  self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
58
67
  self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
59
68
 
69
+ # Parse args
60
70
  self.skip_prefill = skip_prefill
61
-
62
71
  max_bs = model_runner.req_to_token_pool.size
72
+ self.sliding_window_size = model_runner.sliding_window_size
73
+ self.req_to_token = model_runner.req_to_token_pool.req_to_token
74
+ self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
75
+ self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
76
+ self.speculative_num_steps = model_runner.server_args.speculative_num_steps
77
+ self.num_head = (
78
+ model_runner.model_config.num_attention_heads // get_attention_tp_size()
79
+ )
80
+ self.num_kv_head = model_runner.model_config.get_num_kv_heads(
81
+ get_attention_tp_size()
82
+ )
83
+ self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
84
+ self.max_context_len = model_runner.model_config.context_len
85
+ self.device = model_runner.device
86
+ self.device_core_count = get_device_core_count(model_runner.gpu_id)
87
+ self.static_kv_splits = get_bool_env_var(
88
+ "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
89
+ )
90
+ self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
63
91
 
92
+ # Check arguments
64
93
  assert not (
65
94
  model_runner.sliding_window_size is not None
66
95
  and model_runner.model_config.is_encoder_decoder
67
96
  ), "Sliding window and cross attention are not supported together"
68
- self.sliding_window_size = model_runner.sliding_window_size
69
97
 
98
+ # Initialize buffers
70
99
  # TODO(Jianan Ji): Make sure it behaves as expected when kv_indptr_buf is provided and sliding window is enabled
71
100
  if kv_indptr_buf is None:
72
101
  self.kv_indptr = torch.zeros(
@@ -87,9 +116,6 @@ class TritonAttnBackend(AttentionBackend):
87
116
  # When provided a buffer, create a clone for the second buffer
88
117
  self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)
89
118
 
90
- self.req_to_token = model_runner.req_to_token_pool.req_to_token
91
- self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
92
-
93
119
  if not self.skip_prefill:
94
120
  self.qo_indptr = torch.zeros(
95
121
  (max_bs + 1,), dtype=torch.int32, device=model_runner.device
@@ -99,29 +125,9 @@ class TritonAttnBackend(AttentionBackend):
99
125
  (max_bs + 1,), dtype=torch.int64, device=model_runner.device
100
126
  )
101
127
 
102
- self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
103
- self.speculative_num_steps = model_runner.server_args.speculative_num_steps
104
-
105
- self.num_head = (
106
- model_runner.model_config.num_attention_heads // get_attention_tp_size()
107
- )
108
- self.num_kv_head = model_runner.model_config.get_num_kv_heads(
109
- get_attention_tp_size()
110
- )
111
-
112
- self.static_kv_splits = get_bool_env_var(
113
- "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
114
- )
115
- self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
116
- self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
117
-
128
+ # Initialize forward metadata
118
129
  self.forward_metadata: ForwardMetadata = None
119
130
 
120
- self.max_context_len = model_runner.model_config.context_len
121
-
122
- self.device = model_runner.device
123
- self.device_core_count = get_device_core_count(model_runner.gpu_id)
124
-
125
131
  def get_num_kv_splits(
126
132
  self,
127
133
  num_kv_splits: torch.Tensor,
@@ -166,6 +172,7 @@ class TritonAttnBackend(AttentionBackend):
166
172
  window_kv_indptr = self.window_kv_indptr
167
173
  window_kv_indices = None
168
174
  window_num_kv_splits = None
175
+ window_kv_offsets = None
169
176
  spec_info = forward_batch.spec_info
170
177
 
171
178
  if forward_batch.forward_mode.is_decode_or_idle():
@@ -173,7 +180,7 @@ class TritonAttnBackend(AttentionBackend):
173
180
  kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
174
181
  kv_indptr = kv_indptr[: bs + 1]
175
182
  kv_indices = torch.empty(
176
- forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
183
+ forward_batch.seq_lens_sum, dtype=torch.int64, device=self.device
177
184
  )
178
185
  create_flashinfer_kv_indices_triton[(bs,)](
179
186
  self.req_to_token,
@@ -189,7 +196,7 @@ class TritonAttnBackend(AttentionBackend):
189
196
  self.sliding_window_size is not None
190
197
  and self.sliding_window_size > 0
191
198
  ):
192
- window_kv_indptr, window_kv_indices, window_kv_lens = (
199
+ window_kv_indptr, window_kv_indices, window_kv_lens, _ = (
193
200
  update_sliding_window_buffer(
194
201
  self.window_kv_indptr,
195
202
  self.req_to_token,
@@ -239,7 +246,7 @@ class TritonAttnBackend(AttentionBackend):
239
246
  kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
240
247
  kv_indptr = kv_indptr[: bs + 1]
241
248
  kv_indices = torch.empty(
242
- kv_indptr[-1], dtype=torch.int32, device=self.device
249
+ kv_indptr[-1], dtype=torch.int64, device=self.device
243
250
  )
244
251
  create_flashinfer_kv_indices_triton[(bs,)](
245
252
  self.req_to_token,
@@ -252,17 +259,21 @@ class TritonAttnBackend(AttentionBackend):
252
259
  )
253
260
 
254
261
  if self.sliding_window_size is not None and self.sliding_window_size > 0:
255
- window_kv_indptr, window_kv_indices, window_kv_lens = (
256
- update_sliding_window_buffer(
257
- self.window_kv_indptr,
258
- self.req_to_token,
259
- self.sliding_window_size,
260
- forward_batch.seq_lens,
261
- forward_batch.req_pool_indices,
262
- bs,
263
- self.device,
264
- self.token_to_kv_pool_allocator,
265
- )
262
+ # window_kv_offsets is used to calculate the start position in custom mask
263
+ (
264
+ window_kv_indptr,
265
+ window_kv_indices,
266
+ window_kv_lens,
267
+ window_kv_offsets,
268
+ ) = update_sliding_window_buffer(
269
+ self.window_kv_indptr,
270
+ self.req_to_token,
271
+ self.sliding_window_size,
272
+ forward_batch.seq_lens,
273
+ forward_batch.req_pool_indices,
274
+ bs,
275
+ self.device,
276
+ self.token_to_kv_pool_allocator,
266
277
  )
267
278
 
268
279
  custom_mask = spec_info.custom_mask
@@ -286,6 +297,7 @@ class TritonAttnBackend(AttentionBackend):
286
297
  self.req_to_token,
287
298
  )
288
299
  )
300
+ kv_indices = kv_indices.to(torch.int64)
289
301
  mask_indptr = None
290
302
  # TODO(FIXME): This will trigger an invalid Eagle tree when using
291
303
  # `max(spec_info.accept_length_cpu)`.
@@ -301,7 +313,7 @@ class TritonAttnBackend(AttentionBackend):
301
313
  kv_indptr = kv_indptr[: bs + 1]
302
314
  kv_indices = torch.empty(
303
315
  forward_batch.extend_prefix_lens.sum().item(),
304
- dtype=torch.int32,
316
+ dtype=torch.int64,
305
317
  device=self.device,
306
318
  )
307
319
  create_flashinfer_kv_indices_triton[(bs,)](
@@ -315,15 +327,17 @@ class TritonAttnBackend(AttentionBackend):
315
327
  )
316
328
  # Sliding window
317
329
  if self.sliding_window_size is not None and self.sliding_window_size > 0:
318
- window_kv_indptr, window_kv_indices, _ = update_sliding_window_buffer(
319
- self.window_kv_indptr,
320
- self.req_to_token,
321
- self.sliding_window_size,
322
- forward_batch.extend_prefix_lens,
323
- forward_batch.req_pool_indices,
324
- bs,
325
- self.device,
326
- self.token_to_kv_pool_allocator,
330
+ window_kv_indptr, window_kv_indices, _, _ = (
331
+ update_sliding_window_buffer(
332
+ self.window_kv_indptr,
333
+ self.req_to_token,
334
+ self.sliding_window_size,
335
+ forward_batch.extend_prefix_lens,
336
+ forward_batch.req_pool_indices,
337
+ bs,
338
+ self.device,
339
+ self.token_to_kv_pool_allocator,
340
+ )
327
341
  )
328
342
 
329
343
  qo_indptr = self.qo_indptr
@@ -333,7 +347,7 @@ class TritonAttnBackend(AttentionBackend):
333
347
  mask_indptr = None
334
348
  attn_logits = None
335
349
  attn_lse = None
336
- max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
350
+ max_extend_len = max(forward_batch.extend_seq_lens_cpu)
337
351
  num_kv_splits = None
338
352
 
339
353
  self.forward_metadata = ForwardMetadata(
@@ -349,6 +363,7 @@ class TritonAttnBackend(AttentionBackend):
349
363
  window_kv_indptr,
350
364
  window_kv_indices,
351
365
  window_num_kv_splits,
366
+ window_kv_offsets,
352
367
  )
353
368
 
354
369
  def init_cuda_graph_state(
@@ -373,7 +388,7 @@ class TritonAttnBackend(AttentionBackend):
373
388
  if kv_indices_buf is None:
374
389
  self.cuda_graph_kv_indices = torch.zeros(
375
390
  (max_num_tokens * self.max_context_len),
376
- dtype=torch.int32,
391
+ dtype=torch.int64,
377
392
  device=self.device,
378
393
  )
379
394
  else:
@@ -390,7 +405,7 @@ class TritonAttnBackend(AttentionBackend):
390
405
  if kv_indices_buf is None:
391
406
  self.cuda_graph_window_kv_indices = torch.zeros(
392
407
  (max_num_tokens * self.sliding_window_size),
393
- dtype=torch.int32,
408
+ dtype=torch.int64,
394
409
  device=self.device,
395
410
  )
396
411
  else:
@@ -403,6 +418,12 @@ class TritonAttnBackend(AttentionBackend):
403
418
  device=self.device,
404
419
  )
405
420
 
421
+ self.cuda_graph_window_kv_offsets = torch.zeros(
422
+ (max_bs,),
423
+ dtype=torch.int32,
424
+ device=self.device,
425
+ )
426
+
406
427
  def init_forward_metadata_capture_cuda_graph(
407
428
  self,
408
429
  bs: int,
@@ -417,6 +438,7 @@ class TritonAttnBackend(AttentionBackend):
417
438
  window_kv_indptr = self.window_kv_indptr
418
439
  window_kv_indices = None
419
440
  window_num_kv_splits = None
441
+ window_kv_offsets = None
420
442
 
421
443
  if forward_mode.is_decode_or_idle():
422
444
  if spec_info is None:
@@ -439,7 +461,7 @@ class TritonAttnBackend(AttentionBackend):
439
461
  ):
440
462
  window_kv_indices = self.cuda_graph_window_kv_indices
441
463
  window_num_kv_splits = self.cuda_graph_window_num_kv_splits
442
- window_kv_indptr, window_kv_indices, _ = (
464
+ window_kv_indptr, window_kv_indices, _, _ = (
443
465
  update_sliding_window_buffer_cuda_graph(
444
466
  self.window_kv_indptr,
445
467
  window_kv_indices,
@@ -486,13 +508,14 @@ class TritonAttnBackend(AttentionBackend):
486
508
  if self.sliding_window_size is not None and self.sliding_window_size > 0:
487
509
  window_kv_indices = self.cuda_graph_window_kv_indices
488
510
  window_num_kv_splits = self.cuda_graph_window_num_kv_splits
489
- window_kv_indptr, window_kv_indices, _ = (
511
+ window_kv_offsets = self.cuda_graph_window_kv_offsets
512
+ window_kv_indptr, window_kv_indices, _, window_kv_offsets[:bs] = (
490
513
  update_sliding_window_buffer_cuda_graph(
491
514
  self.window_kv_indptr,
492
515
  window_kv_indices,
493
516
  self.req_to_token,
494
517
  self.sliding_window_size,
495
- seq_lens,
518
+ seq_lens[:bs],
496
519
  req_pool_indices,
497
520
  bs,
498
521
  self.token_to_kv_pool_allocator,
@@ -554,6 +577,7 @@ class TritonAttnBackend(AttentionBackend):
554
577
  window_kv_indptr,
555
578
  window_kv_indices,
556
579
  window_num_kv_splits,
580
+ window_kv_offsets,
557
581
  )
558
582
 
559
583
  def init_forward_metadata_replay_cuda_graph(
@@ -592,7 +616,7 @@ class TritonAttnBackend(AttentionBackend):
592
616
  ):
593
617
  window_num_kv_splits = self.cuda_graph_window_num_kv_splits
594
618
  window_kv_indices = self.cuda_graph_window_kv_indices
595
- _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
619
+ _, _, window_kv_lens, _ = update_sliding_window_buffer_cuda_graph(
596
620
  self.window_kv_indptr,
597
621
  window_kv_indices,
598
622
  self.req_to_token,
@@ -638,15 +662,18 @@ class TritonAttnBackend(AttentionBackend):
638
662
  if self.sliding_window_size is not None and self.sliding_window_size > 0:
639
663
  window_num_kv_splits = self.cuda_graph_window_num_kv_splits
640
664
  window_kv_indices = self.cuda_graph_window_kv_indices
641
- _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
642
- self.window_kv_indptr,
643
- window_kv_indices,
644
- self.req_to_token,
645
- self.sliding_window_size,
646
- seq_lens,
647
- req_pool_indices,
648
- bs,
649
- self.token_to_kv_pool_allocator,
665
+ window_kv_offsets = self.cuda_graph_window_kv_offsets
666
+ _, _, window_kv_lens, window_kv_offsets[:bs] = (
667
+ update_sliding_window_buffer_cuda_graph(
668
+ self.window_kv_indptr,
669
+ window_kv_indices,
670
+ self.req_to_token,
671
+ self.sliding_window_size,
672
+ seq_lens[:bs],
673
+ req_pool_indices,
674
+ bs,
675
+ self.token_to_kv_pool_allocator,
676
+ )
650
677
  )
651
678
  custom_mask = self.cuda_graph_custom_mask
652
679
  custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
@@ -699,6 +726,8 @@ class TritonAttnBackend(AttentionBackend):
699
726
  layer, forward_batch.out_cache_loc, k, v
700
727
  )
701
728
 
729
+ logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
730
+
702
731
  causal = True
703
732
  if layer.attn_type == AttentionType.ENCODER_ONLY:
704
733
  causal = False
@@ -709,10 +738,12 @@ class TritonAttnBackend(AttentionBackend):
709
738
  ) # Needed for sliding window mask
710
739
  kv_indptr = self.forward_metadata.window_kv_indptr
711
740
  kv_indices = self.forward_metadata.window_kv_indices
741
+ window_kv_offsets = self.forward_metadata.window_kv_offsets
712
742
  else:
713
743
  sliding_window_size = -1
714
744
  kv_indptr = self.forward_metadata.kv_indptr
715
745
  kv_indices = self.forward_metadata.kv_indices
746
+ window_kv_offsets = None
716
747
 
717
748
  self.extend_attention_fwd(
718
749
  q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
@@ -729,9 +760,11 @@ class TritonAttnBackend(AttentionBackend):
729
760
  self.forward_metadata.mask_indptr,
730
761
  self.forward_metadata.max_extend_len,
731
762
  layer.scaling,
732
- layer.logit_cap,
763
+ logit_cap=logits_soft_cap,
733
764
  sliding_window_size=sliding_window_size,
734
765
  sinks=sinks,
766
+ window_kv_offsets=window_kv_offsets,
767
+ xai_temperature_len=layer.xai_temperature_len,
735
768
  )
736
769
  return o
737
770
 
@@ -755,6 +788,8 @@ class TritonAttnBackend(AttentionBackend):
755
788
  else:
756
789
  o = torch.empty_like(q)
757
790
 
791
+ logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
792
+
758
793
  if save_kv_cache:
759
794
  forward_batch.token_to_kv_pool.set_kv_buffer(
760
795
  layer, forward_batch.out_cache_loc, k, v
@@ -779,8 +814,9 @@ class TritonAttnBackend(AttentionBackend):
779
814
  self.forward_metadata.num_kv_splits,
780
815
  self.max_kv_splits,
781
816
  layer.scaling,
782
- layer.logit_cap,
817
+ logit_cap=logits_soft_cap,
783
818
  sinks=sinks,
819
+ xai_temperature_len=layer.xai_temperature_len,
784
820
  )
785
821
  return o
786
822
 
@@ -867,7 +903,7 @@ class TritonMultiStepDraftBackend:
867
903
  self.speculative_num_steps,
868
904
  forward_batch.batch_size * self.topk * self.max_context_len,
869
905
  ),
870
- dtype=torch.int32,
906
+ dtype=torch.int64,
871
907
  device=self.device,
872
908
  )
873
909
 
@@ -885,7 +921,7 @@ class TritonMultiStepDraftBackend:
885
921
  def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
886
922
  self.cuda_graph_kv_indices = torch.zeros(
887
923
  (self.speculative_num_steps, max_num_tokens * self.max_context_len),
888
- dtype=torch.int32,
924
+ dtype=torch.int64,
889
925
  device=self.device,
890
926
  )
891
927
  for i in range(self.speculative_num_steps):
@@ -994,7 +1030,7 @@ def update_sliding_window_buffer(
994
1030
  window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
995
1031
  window_kv_indptr = window_kv_indptr[: bs + 1]
996
1032
  window_kv_indices = torch.empty(
997
- window_kv_indptr[-1], dtype=torch.int32, device=device
1033
+ window_kv_indptr[-1], dtype=torch.int64, device=device
998
1034
  )
999
1035
  window_kv_start_idx = seq_lens - window_kv_lens
1000
1036
  create_flashinfer_kv_indices_triton[(bs,)](
@@ -1014,7 +1050,7 @@ def update_sliding_window_buffer(
1014
1050
  window_kv_indices[:kv_last_index]
1015
1051
  )
1016
1052
  )
1017
- return window_kv_indptr, window_kv_indices, window_kv_lens
1053
+ return window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_start_idx
1018
1054
 
1019
1055
 
1020
1056
  def update_sliding_window_buffer_cuda_graph(
@@ -1051,4 +1087,4 @@ def update_sliding_window_buffer_cuda_graph(
1051
1087
  window_kv_indices[:kv_last_index]
1052
1088
  )
1053
1089
  )
1054
- return window_kv_indptr, window_kv_indices, window_kv_lens
1090
+ return window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_start_idx
@@ -69,6 +69,7 @@ def _fwd_kernel_stage1(
69
69
  logit_cap: tl.constexpr,
70
70
  Lk: tl.constexpr,
71
71
  Lv: tl.constexpr,
72
+ xai_temperature_len: tl.constexpr,
72
73
  ):
73
74
  cur_batch = tl.program_id(0)
74
75
  cur_head = tl.program_id(1)
@@ -85,6 +86,12 @@ def _fwd_kernel_stage1(
85
86
  cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
86
87
  kv_splits = tl.load(num_kv_splits + cur_batch)
87
88
 
89
+ if xai_temperature_len > 0:
90
+ offs_qidx = cur_batch_seq_len - 1
91
+ xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
92
+ _qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale
93
+ xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0)
94
+
88
95
  off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
89
96
 
90
97
  kv_len_per_split = (
@@ -122,6 +129,9 @@ def _fwd_kernel_stage1(
122
129
  if logit_cap > 0:
123
130
  qk = logit_cap * tanh(qk / logit_cap)
124
131
 
132
+ if xai_temperature_len > 0:
133
+ qk *= xai_temperature_reg
134
+
125
135
  qk = tl.where(offs_n < split_kv_end, qk, float("-inf"))
126
136
 
127
137
  offs_buf_v = (
@@ -181,6 +191,7 @@ def _decode_att_m_fwd(
181
191
  max_kv_splits,
182
192
  sm_scale,
183
193
  logit_cap,
194
+ xai_temperature_len=-1,
184
195
  ):
185
196
  BLOCK = 64
186
197
  # [TODO] work around SGPR limit on MI3xx
@@ -190,7 +201,7 @@ def _decode_att_m_fwd(
190
201
  Lk = k_buffer.shape[-1]
191
202
  Lv = v_buffer.shape[-1]
192
203
 
193
- batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
204
+ batch, head_num = q.shape[0], q.shape[1]
194
205
 
195
206
  grid = (batch, head_num, MAX_KV_SPLITS)
196
207
  kv_group_num = q.shape[1] // k_buffer.shape[1]
@@ -230,6 +241,7 @@ def _decode_att_m_fwd(
230
241
  BLOCK_N=BLOCK,
231
242
  MIN_BLOCK_KV=_MIN_BLOCK_KV,
232
243
  logit_cap=logit_cap,
244
+ xai_temperature_len=xai_temperature_len,
233
245
  num_warps=num_warps,
234
246
  num_stages=2,
235
247
  Lk=Lk,
@@ -266,6 +278,7 @@ def _fwd_grouped_kernel_stage1(
266
278
  BLOCK_H: tl.constexpr,
267
279
  MIN_BLOCK_KV: tl.constexpr,
268
280
  logit_cap: tl.constexpr,
281
+ xai_temperature_len: tl.constexpr,
269
282
  Lk: tl.constexpr,
270
283
  Lv: tl.constexpr,
271
284
  ):
@@ -291,6 +304,12 @@ def _fwd_grouped_kernel_stage1(
291
304
  cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
292
305
  kv_splits = tl.load(num_kv_splits + cur_batch)
293
306
 
307
+ if xai_temperature_len > 0:
308
+ offs_qidx = cur_batch_seq_len - 1
309
+ xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
310
+ _qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale
311
+ xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0)
312
+
294
313
  offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
295
314
 
296
315
  if BLOCK_DPE > 0:
@@ -351,6 +370,9 @@ def _fwd_grouped_kernel_stage1(
351
370
  if logit_cap > 0:
352
371
  qk = logit_cap * tanh(qk / logit_cap)
353
372
 
373
+ if xai_temperature_len > 0:
374
+ qk *= xai_temperature_reg[:, None]
375
+
354
376
  qk = tl.where(
355
377
  mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
356
378
  )
@@ -413,6 +435,7 @@ def _decode_grouped_att_m_fwd(
413
435
  max_kv_splits,
414
436
  sm_scale,
415
437
  logit_cap,
438
+ xai_temperature_len=-1,
416
439
  ):
417
440
  BLOCK = 32
418
441
  Lk = k_buffer.shape[-1]
@@ -433,7 +456,7 @@ def _decode_grouped_att_m_fwd(
433
456
  BLOCK_DPE = 0
434
457
  BLOCK_DV = triton.next_power_of_2(Lv)
435
458
 
436
- batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
459
+ batch, head_num = q.shape[0], q.shape[1]
437
460
  kv_group_num = q.shape[1] // k_buffer.shape[1]
438
461
 
439
462
  BLOCK_H = 16
@@ -480,6 +503,7 @@ def _decode_grouped_att_m_fwd(
480
503
  BLOCK_H=BLOCK_H,
481
504
  MIN_BLOCK_KV=_MIN_BLOCK_KV,
482
505
  logit_cap=logit_cap,
506
+ xai_temperature_len=xai_temperature_len,
483
507
  num_warps=4,
484
508
  num_stages=num_stages,
485
509
  Lk=Lk,
@@ -620,6 +644,7 @@ def decode_attention_fwd_normal(
620
644
  sm_scale,
621
645
  logit_cap=0.0,
622
646
  sinks=None,
647
+ xai_temperature_len=-1,
623
648
  ):
624
649
  _decode_att_m_fwd(
625
650
  q,
@@ -633,6 +658,7 @@ def decode_attention_fwd_normal(
633
658
  max_kv_splits,
634
659
  sm_scale,
635
660
  logit_cap,
661
+ xai_temperature_len,
636
662
  )
637
663
  _decode_softmax_reducev_fwd(
638
664
  attn_logits,
@@ -661,6 +687,7 @@ def decode_attention_fwd_grouped(
661
687
  sm_scale,
662
688
  logit_cap=0.0,
663
689
  sinks=None,
690
+ xai_temperature_len=-1,
664
691
  ):
665
692
  _decode_grouped_att_m_fwd(
666
693
  q,
@@ -674,6 +701,7 @@ def decode_attention_fwd_grouped(
674
701
  max_kv_splits,
675
702
  sm_scale,
676
703
  logit_cap,
704
+ xai_temperature_len,
677
705
  )
678
706
  _decode_softmax_reducev_fwd(
679
707
  attn_logits,
@@ -702,6 +730,7 @@ def decode_attention_fwd(
702
730
  sm_scale,
703
731
  logit_cap=0.0,
704
732
  sinks=None,
733
+ xai_temperature_len=-1,
705
734
  ):
706
735
  assert max_kv_splits == attn_logits.shape[2]
707
736
  assert q.shape[0] <= kv_indptr.shape[0] - 1
@@ -725,6 +754,7 @@ def decode_attention_fwd(
725
754
  sm_scale,
726
755
  logit_cap=logit_cap,
727
756
  sinks=sinks,
757
+ xai_temperature_len=xai_temperature_len,
728
758
  )
729
759
  else:
730
760
  # GQA/MQA/MLA
@@ -742,4 +772,5 @@ def decode_attention_fwd(
742
772
  sm_scale,
743
773
  logit_cap=logit_cap,
744
774
  sinks=sinks,
775
+ xai_temperature_len=xai_temperature_len,
745
776
  )
@@ -52,6 +52,7 @@ def _fwd_kernel(
52
52
  mask_ptr,
53
53
  mask_indptr,
54
54
  sink_ptr,
55
+ window_kv_offset_ptr,
55
56
  sm_scale,
56
57
  kv_group_num,
57
58
  stride_qbs,
@@ -68,6 +69,7 @@ def _fwd_kernel(
68
69
  stride_buf_vh,
69
70
  SLIDING_WINDOW_SIZE: tl.constexpr,
70
71
  logit_cap: tl.constexpr,
72
+ xai_temperature_len: tl.constexpr,
71
73
  Lq: tl.constexpr,
72
74
  Lv: tl.constexpr,
73
75
  BLOCK_DMODEL: tl.constexpr,
@@ -95,6 +97,11 @@ def _fwd_kernel(
95
97
  if USE_CUSTOM_MASK:
96
98
  cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)
97
99
 
100
+ # For SWA, we should only load the mask in the sliding window
101
+ window_kv_offset = 0
102
+ if USE_CUSTOM_MASK and SLIDING_WINDOW_SIZE > 0:
103
+ window_kv_offset = tl.load(window_kv_offset_ptr + cur_seq)
104
+
98
105
  offs_d = tl.arange(0, BLOCK_DMODEL)
99
106
  offs_dv = tl.arange(0, BLOCK_DV)
100
107
  offs_m = tl.arange(0, BLOCK_M)
@@ -103,6 +110,15 @@ def _fwd_kernel(
103
110
  mask_d = offs_d < Lq
104
111
  mask_dv = offs_dv < Lv
105
112
 
113
+ if xai_temperature_len > 0:
114
+ offs_qidx = cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m
115
+ xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
116
+ xai_temperature_reg = tl.where(
117
+ offs_qidx > xai_temperature_len,
118
+ tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale,
119
+ 1.0,
120
+ )
121
+
106
122
  offs_q = (
107
123
  (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
108
124
  * stride_qbs
@@ -139,7 +155,9 @@ def _fwd_kernel(
139
155
  custom_mask = tl.load(
140
156
  mask_ptr
141
157
  + cur_seq_mask_start_idx
142
- + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
158
+ + (cur_block_m * BLOCK_M + offs_m[:, None])
159
+ * (cur_seq_len + window_kv_offset)
160
+ + window_kv_offset
143
161
  + start_n
144
162
  + offs_n[None, :],
145
163
  mask=(mask_m[:, None] & mask_n[None, :]),
@@ -195,6 +213,9 @@ def _fwd_kernel(
195
213
  if logit_cap > 0:
196
214
  qk = logit_cap * tanh(qk / logit_cap)
197
215
 
216
+ if xai_temperature_len > 0:
217
+ qk *= xai_temperature_reg[:, None]
218
+
198
219
  qk = tl.where(final_mask, qk, float("-inf"))
199
220
 
200
221
  row_max = tl.max(qk, 1)
@@ -236,7 +257,9 @@ def _fwd_kernel(
236
257
  custom_mask = tl.load(
237
258
  mask_ptr
238
259
  + cur_seq_mask_start_idx
239
- + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
260
+ + (cur_block_m * BLOCK_M + offs_m[:, None])
261
+ * (cur_seq_len + window_kv_offset)
262
+ + window_kv_offset
240
263
  + cur_seq_len_prefix
241
264
  + start_n
242
265
  + offs_n[None, :],
@@ -296,6 +319,9 @@ def _fwd_kernel(
296
319
  if logit_cap > 0:
297
320
  qk = logit_cap * tanh(qk / logit_cap)
298
321
 
322
+ if xai_temperature_len > 0:
323
+ qk *= xai_temperature_reg[:, None]
324
+
299
325
  qk = tl.where(final_mask, qk, float("-inf"))
300
326
 
301
327
  row_max = tl.max(qk, 1)
@@ -362,6 +388,8 @@ def extend_attention_fwd(
362
388
  skip_prefix_custom_mask=True,
363
389
  sliding_window_size=-1,
364
390
  sinks=None,
391
+ window_kv_offsets=None,
392
+ xai_temperature_len=-1,
365
393
  ):
366
394
  """
367
395
  q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -449,6 +477,7 @@ def extend_attention_fwd(
449
477
  custom_mask,
450
478
  mask_indptr,
451
479
  sinks,
480
+ window_kv_offsets,
452
481
  sm_scale,
453
482
  kv_group_num,
454
483
  q_extend.stride(0),
@@ -465,6 +494,7 @@ def extend_attention_fwd(
465
494
  v_buffer.stride(1),
466
495
  SLIDING_WINDOW_SIZE=sliding_window_size,
467
496
  logit_cap=logit_cap,
497
+ xai_temperature_len=xai_temperature_len,
468
498
  BLOCK_DMODEL=BLOCK_DMODEL,
469
499
  BLOCK_DPE=BLOCK_DPE,
470
500
  BLOCK_DV=BLOCK_DV,