sglang 0.5.0rc2__py3-none-any.whl → 0.5.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -0
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +375 -51
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/top_level.txt +0 -0
@@ -20,6 +20,14 @@ if TYPE_CHECKING:
20
20
  from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
21
21
 
22
22
 
23
+ def logit_capping_mod(logit_capping_method, logit_cap):
24
+ # positive logit_cap -> tanh cap
25
+ if logit_capping_method == "tanh":
26
+ return logit_cap
27
+ else:
28
+ raise ValueError()
29
+
30
+
23
31
  @dataclass
24
32
  class ForwardMetadata:
25
33
  attn_logits: torch.Tensor
@@ -35,6 +43,7 @@ class ForwardMetadata:
35
43
  window_kv_indptr: torch.Tensor
36
44
  window_kv_indices: torch.Tensor
37
45
  window_num_kv_splits: torch.Tensor
46
+ window_kv_offsets: torch.Tensor
38
47
 
39
48
 
40
49
  class TritonAttnBackend(AttentionBackend):
@@ -163,6 +172,7 @@ class TritonAttnBackend(AttentionBackend):
163
172
  window_kv_indptr = self.window_kv_indptr
164
173
  window_kv_indices = None
165
174
  window_num_kv_splits = None
175
+ window_kv_offsets = None
166
176
  spec_info = forward_batch.spec_info
167
177
 
168
178
  if forward_batch.forward_mode.is_decode_or_idle():
@@ -170,7 +180,7 @@ class TritonAttnBackend(AttentionBackend):
170
180
  kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
171
181
  kv_indptr = kv_indptr[: bs + 1]
172
182
  kv_indices = torch.empty(
173
- forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
183
+ forward_batch.seq_lens_sum, dtype=torch.int64, device=self.device
174
184
  )
175
185
  create_flashinfer_kv_indices_triton[(bs,)](
176
186
  self.req_to_token,
@@ -186,7 +196,7 @@ class TritonAttnBackend(AttentionBackend):
186
196
  self.sliding_window_size is not None
187
197
  and self.sliding_window_size > 0
188
198
  ):
189
- window_kv_indptr, window_kv_indices, window_kv_lens = (
199
+ window_kv_indptr, window_kv_indices, window_kv_lens, _ = (
190
200
  update_sliding_window_buffer(
191
201
  self.window_kv_indptr,
192
202
  self.req_to_token,
@@ -236,7 +246,7 @@ class TritonAttnBackend(AttentionBackend):
236
246
  kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
237
247
  kv_indptr = kv_indptr[: bs + 1]
238
248
  kv_indices = torch.empty(
239
- kv_indptr[-1], dtype=torch.int32, device=self.device
249
+ kv_indptr[-1], dtype=torch.int64, device=self.device
240
250
  )
241
251
  create_flashinfer_kv_indices_triton[(bs,)](
242
252
  self.req_to_token,
@@ -249,17 +259,21 @@ class TritonAttnBackend(AttentionBackend):
249
259
  )
250
260
 
251
261
  if self.sliding_window_size is not None and self.sliding_window_size > 0:
252
- window_kv_indptr, window_kv_indices, window_kv_lens = (
253
- update_sliding_window_buffer(
254
- self.window_kv_indptr,
255
- self.req_to_token,
256
- self.sliding_window_size,
257
- forward_batch.seq_lens,
258
- forward_batch.req_pool_indices,
259
- bs,
260
- self.device,
261
- self.token_to_kv_pool_allocator,
262
- )
262
+ # window_kv_offsets is used to calculate the start position in custom mask
263
+ (
264
+ window_kv_indptr,
265
+ window_kv_indices,
266
+ window_kv_lens,
267
+ window_kv_offsets,
268
+ ) = update_sliding_window_buffer(
269
+ self.window_kv_indptr,
270
+ self.req_to_token,
271
+ self.sliding_window_size,
272
+ forward_batch.seq_lens,
273
+ forward_batch.req_pool_indices,
274
+ bs,
275
+ self.device,
276
+ self.token_to_kv_pool_allocator,
263
277
  )
264
278
 
265
279
  custom_mask = spec_info.custom_mask
@@ -283,6 +297,7 @@ class TritonAttnBackend(AttentionBackend):
283
297
  self.req_to_token,
284
298
  )
285
299
  )
300
+ kv_indices = kv_indices.to(torch.int64)
286
301
  mask_indptr = None
287
302
  # TODO(FIXME): This will trigger an invalid Eagle tree when using
288
303
  # `max(spec_info.accept_length_cpu)`.
@@ -298,7 +313,7 @@ class TritonAttnBackend(AttentionBackend):
298
313
  kv_indptr = kv_indptr[: bs + 1]
299
314
  kv_indices = torch.empty(
300
315
  forward_batch.extend_prefix_lens.sum().item(),
301
- dtype=torch.int32,
316
+ dtype=torch.int64,
302
317
  device=self.device,
303
318
  )
304
319
  create_flashinfer_kv_indices_triton[(bs,)](
@@ -312,15 +327,17 @@ class TritonAttnBackend(AttentionBackend):
312
327
  )
313
328
  # Sliding window
314
329
  if self.sliding_window_size is not None and self.sliding_window_size > 0:
315
- window_kv_indptr, window_kv_indices, _ = update_sliding_window_buffer(
316
- self.window_kv_indptr,
317
- self.req_to_token,
318
- self.sliding_window_size,
319
- forward_batch.extend_prefix_lens,
320
- forward_batch.req_pool_indices,
321
- bs,
322
- self.device,
323
- self.token_to_kv_pool_allocator,
330
+ window_kv_indptr, window_kv_indices, _, _ = (
331
+ update_sliding_window_buffer(
332
+ self.window_kv_indptr,
333
+ self.req_to_token,
334
+ self.sliding_window_size,
335
+ forward_batch.extend_prefix_lens,
336
+ forward_batch.req_pool_indices,
337
+ bs,
338
+ self.device,
339
+ self.token_to_kv_pool_allocator,
340
+ )
324
341
  )
325
342
 
326
343
  qo_indptr = self.qo_indptr
@@ -346,6 +363,7 @@ class TritonAttnBackend(AttentionBackend):
346
363
  window_kv_indptr,
347
364
  window_kv_indices,
348
365
  window_num_kv_splits,
366
+ window_kv_offsets,
349
367
  )
350
368
 
351
369
  def init_cuda_graph_state(
@@ -370,7 +388,7 @@ class TritonAttnBackend(AttentionBackend):
370
388
  if kv_indices_buf is None:
371
389
  self.cuda_graph_kv_indices = torch.zeros(
372
390
  (max_num_tokens * self.max_context_len),
373
- dtype=torch.int32,
391
+ dtype=torch.int64,
374
392
  device=self.device,
375
393
  )
376
394
  else:
@@ -387,7 +405,7 @@ class TritonAttnBackend(AttentionBackend):
387
405
  if kv_indices_buf is None:
388
406
  self.cuda_graph_window_kv_indices = torch.zeros(
389
407
  (max_num_tokens * self.sliding_window_size),
390
- dtype=torch.int32,
408
+ dtype=torch.int64,
391
409
  device=self.device,
392
410
  )
393
411
  else:
@@ -400,6 +418,12 @@ class TritonAttnBackend(AttentionBackend):
400
418
  device=self.device,
401
419
  )
402
420
 
421
+ self.cuda_graph_window_kv_offsets = torch.zeros(
422
+ (max_bs,),
423
+ dtype=torch.int32,
424
+ device=self.device,
425
+ )
426
+
403
427
  def init_forward_metadata_capture_cuda_graph(
404
428
  self,
405
429
  bs: int,
@@ -414,6 +438,7 @@ class TritonAttnBackend(AttentionBackend):
414
438
  window_kv_indptr = self.window_kv_indptr
415
439
  window_kv_indices = None
416
440
  window_num_kv_splits = None
441
+ window_kv_offsets = None
417
442
 
418
443
  if forward_mode.is_decode_or_idle():
419
444
  if spec_info is None:
@@ -436,7 +461,7 @@ class TritonAttnBackend(AttentionBackend):
436
461
  ):
437
462
  window_kv_indices = self.cuda_graph_window_kv_indices
438
463
  window_num_kv_splits = self.cuda_graph_window_num_kv_splits
439
- window_kv_indptr, window_kv_indices, _ = (
464
+ window_kv_indptr, window_kv_indices, _, _ = (
440
465
  update_sliding_window_buffer_cuda_graph(
441
466
  self.window_kv_indptr,
442
467
  window_kv_indices,
@@ -483,13 +508,14 @@ class TritonAttnBackend(AttentionBackend):
483
508
  if self.sliding_window_size is not None and self.sliding_window_size > 0:
484
509
  window_kv_indices = self.cuda_graph_window_kv_indices
485
510
  window_num_kv_splits = self.cuda_graph_window_num_kv_splits
486
- window_kv_indptr, window_kv_indices, _ = (
511
+ window_kv_offsets = self.cuda_graph_window_kv_offsets
512
+ window_kv_indptr, window_kv_indices, _, window_kv_offsets[:bs] = (
487
513
  update_sliding_window_buffer_cuda_graph(
488
514
  self.window_kv_indptr,
489
515
  window_kv_indices,
490
516
  self.req_to_token,
491
517
  self.sliding_window_size,
492
- seq_lens,
518
+ seq_lens[:bs],
493
519
  req_pool_indices,
494
520
  bs,
495
521
  self.token_to_kv_pool_allocator,
@@ -551,6 +577,7 @@ class TritonAttnBackend(AttentionBackend):
551
577
  window_kv_indptr,
552
578
  window_kv_indices,
553
579
  window_num_kv_splits,
580
+ window_kv_offsets,
554
581
  )
555
582
 
556
583
  def init_forward_metadata_replay_cuda_graph(
@@ -589,7 +616,7 @@ class TritonAttnBackend(AttentionBackend):
589
616
  ):
590
617
  window_num_kv_splits = self.cuda_graph_window_num_kv_splits
591
618
  window_kv_indices = self.cuda_graph_window_kv_indices
592
- _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
619
+ _, _, window_kv_lens, _ = update_sliding_window_buffer_cuda_graph(
593
620
  self.window_kv_indptr,
594
621
  window_kv_indices,
595
622
  self.req_to_token,
@@ -635,15 +662,18 @@ class TritonAttnBackend(AttentionBackend):
635
662
  if self.sliding_window_size is not None and self.sliding_window_size > 0:
636
663
  window_num_kv_splits = self.cuda_graph_window_num_kv_splits
637
664
  window_kv_indices = self.cuda_graph_window_kv_indices
638
- _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
639
- self.window_kv_indptr,
640
- window_kv_indices,
641
- self.req_to_token,
642
- self.sliding_window_size,
643
- seq_lens,
644
- req_pool_indices,
645
- bs,
646
- self.token_to_kv_pool_allocator,
665
+ window_kv_offsets = self.cuda_graph_window_kv_offsets
666
+ _, _, window_kv_lens, window_kv_offsets[:bs] = (
667
+ update_sliding_window_buffer_cuda_graph(
668
+ self.window_kv_indptr,
669
+ window_kv_indices,
670
+ self.req_to_token,
671
+ self.sliding_window_size,
672
+ seq_lens[:bs],
673
+ req_pool_indices,
674
+ bs,
675
+ self.token_to_kv_pool_allocator,
676
+ )
647
677
  )
648
678
  custom_mask = self.cuda_graph_custom_mask
649
679
  custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
@@ -696,6 +726,8 @@ class TritonAttnBackend(AttentionBackend):
696
726
  layer, forward_batch.out_cache_loc, k, v
697
727
  )
698
728
 
729
+ logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
730
+
699
731
  causal = True
700
732
  if layer.attn_type == AttentionType.ENCODER_ONLY:
701
733
  causal = False
@@ -706,10 +738,12 @@ class TritonAttnBackend(AttentionBackend):
706
738
  ) # Needed for sliding window mask
707
739
  kv_indptr = self.forward_metadata.window_kv_indptr
708
740
  kv_indices = self.forward_metadata.window_kv_indices
741
+ window_kv_offsets = self.forward_metadata.window_kv_offsets
709
742
  else:
710
743
  sliding_window_size = -1
711
744
  kv_indptr = self.forward_metadata.kv_indptr
712
745
  kv_indices = self.forward_metadata.kv_indices
746
+ window_kv_offsets = None
713
747
 
714
748
  self.extend_attention_fwd(
715
749
  q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
@@ -726,9 +760,11 @@ class TritonAttnBackend(AttentionBackend):
726
760
  self.forward_metadata.mask_indptr,
727
761
  self.forward_metadata.max_extend_len,
728
762
  layer.scaling,
729
- layer.logit_cap,
763
+ logit_cap=logits_soft_cap,
730
764
  sliding_window_size=sliding_window_size,
731
765
  sinks=sinks,
766
+ window_kv_offsets=window_kv_offsets,
767
+ xai_temperature_len=layer.xai_temperature_len,
732
768
  )
733
769
  return o
734
770
 
@@ -752,6 +788,8 @@ class TritonAttnBackend(AttentionBackend):
752
788
  else:
753
789
  o = torch.empty_like(q)
754
790
 
791
+ logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
792
+
755
793
  if save_kv_cache:
756
794
  forward_batch.token_to_kv_pool.set_kv_buffer(
757
795
  layer, forward_batch.out_cache_loc, k, v
@@ -776,8 +814,9 @@ class TritonAttnBackend(AttentionBackend):
776
814
  self.forward_metadata.num_kv_splits,
777
815
  self.max_kv_splits,
778
816
  layer.scaling,
779
- layer.logit_cap,
817
+ logit_cap=logits_soft_cap,
780
818
  sinks=sinks,
819
+ xai_temperature_len=layer.xai_temperature_len,
781
820
  )
782
821
  return o
783
822
 
@@ -864,7 +903,7 @@ class TritonMultiStepDraftBackend:
864
903
  self.speculative_num_steps,
865
904
  forward_batch.batch_size * self.topk * self.max_context_len,
866
905
  ),
867
- dtype=torch.int32,
906
+ dtype=torch.int64,
868
907
  device=self.device,
869
908
  )
870
909
 
@@ -882,7 +921,7 @@ class TritonMultiStepDraftBackend:
882
921
  def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
883
922
  self.cuda_graph_kv_indices = torch.zeros(
884
923
  (self.speculative_num_steps, max_num_tokens * self.max_context_len),
885
- dtype=torch.int32,
924
+ dtype=torch.int64,
886
925
  device=self.device,
887
926
  )
888
927
  for i in range(self.speculative_num_steps):
@@ -991,7 +1030,7 @@ def update_sliding_window_buffer(
991
1030
  window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
992
1031
  window_kv_indptr = window_kv_indptr[: bs + 1]
993
1032
  window_kv_indices = torch.empty(
994
- window_kv_indptr[-1], dtype=torch.int32, device=device
1033
+ window_kv_indptr[-1], dtype=torch.int64, device=device
995
1034
  )
996
1035
  window_kv_start_idx = seq_lens - window_kv_lens
997
1036
  create_flashinfer_kv_indices_triton[(bs,)](
@@ -1011,7 +1050,7 @@ def update_sliding_window_buffer(
1011
1050
  window_kv_indices[:kv_last_index]
1012
1051
  )
1013
1052
  )
1014
- return window_kv_indptr, window_kv_indices, window_kv_lens
1053
+ return window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_start_idx
1015
1054
 
1016
1055
 
1017
1056
  def update_sliding_window_buffer_cuda_graph(
@@ -1048,4 +1087,4 @@ def update_sliding_window_buffer_cuda_graph(
1048
1087
  window_kv_indices[:kv_last_index]
1049
1088
  )
1050
1089
  )
1051
- return window_kv_indptr, window_kv_indices, window_kv_lens
1090
+ return window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_start_idx
@@ -69,6 +69,7 @@ def _fwd_kernel_stage1(
69
69
  logit_cap: tl.constexpr,
70
70
  Lk: tl.constexpr,
71
71
  Lv: tl.constexpr,
72
+ xai_temperature_len: tl.constexpr,
72
73
  ):
73
74
  cur_batch = tl.program_id(0)
74
75
  cur_head = tl.program_id(1)
@@ -85,6 +86,12 @@ def _fwd_kernel_stage1(
85
86
  cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
86
87
  kv_splits = tl.load(num_kv_splits + cur_batch)
87
88
 
89
+ if xai_temperature_len > 0:
90
+ offs_qidx = cur_batch_seq_len - 1
91
+ xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
92
+ _qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale
93
+ xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0)
94
+
88
95
  off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
89
96
 
90
97
  kv_len_per_split = (
@@ -122,6 +129,9 @@ def _fwd_kernel_stage1(
122
129
  if logit_cap > 0:
123
130
  qk = logit_cap * tanh(qk / logit_cap)
124
131
 
132
+ if xai_temperature_len > 0:
133
+ qk *= xai_temperature_reg
134
+
125
135
  qk = tl.where(offs_n < split_kv_end, qk, float("-inf"))
126
136
 
127
137
  offs_buf_v = (
@@ -181,6 +191,7 @@ def _decode_att_m_fwd(
181
191
  max_kv_splits,
182
192
  sm_scale,
183
193
  logit_cap,
194
+ xai_temperature_len=-1,
184
195
  ):
185
196
  BLOCK = 64
186
197
  # [TODO] work around SGPR limit on MI3xx
@@ -190,7 +201,7 @@ def _decode_att_m_fwd(
190
201
  Lk = k_buffer.shape[-1]
191
202
  Lv = v_buffer.shape[-1]
192
203
 
193
- batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
204
+ batch, head_num = q.shape[0], q.shape[1]
194
205
 
195
206
  grid = (batch, head_num, MAX_KV_SPLITS)
196
207
  kv_group_num = q.shape[1] // k_buffer.shape[1]
@@ -230,6 +241,7 @@ def _decode_att_m_fwd(
230
241
  BLOCK_N=BLOCK,
231
242
  MIN_BLOCK_KV=_MIN_BLOCK_KV,
232
243
  logit_cap=logit_cap,
244
+ xai_temperature_len=xai_temperature_len,
233
245
  num_warps=num_warps,
234
246
  num_stages=2,
235
247
  Lk=Lk,
@@ -266,6 +278,7 @@ def _fwd_grouped_kernel_stage1(
266
278
  BLOCK_H: tl.constexpr,
267
279
  MIN_BLOCK_KV: tl.constexpr,
268
280
  logit_cap: tl.constexpr,
281
+ xai_temperature_len: tl.constexpr,
269
282
  Lk: tl.constexpr,
270
283
  Lv: tl.constexpr,
271
284
  ):
@@ -291,6 +304,12 @@ def _fwd_grouped_kernel_stage1(
291
304
  cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
292
305
  kv_splits = tl.load(num_kv_splits + cur_batch)
293
306
 
307
+ if xai_temperature_len > 0:
308
+ offs_qidx = cur_batch_seq_len - 1
309
+ xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
310
+ _qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale
311
+ xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0)
312
+
294
313
  offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
295
314
 
296
315
  if BLOCK_DPE > 0:
@@ -351,6 +370,9 @@ def _fwd_grouped_kernel_stage1(
351
370
  if logit_cap > 0:
352
371
  qk = logit_cap * tanh(qk / logit_cap)
353
372
 
373
+ if xai_temperature_len > 0:
374
+ qk *= xai_temperature_reg[:, None]
375
+
354
376
  qk = tl.where(
355
377
  mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
356
378
  )
@@ -413,6 +435,7 @@ def _decode_grouped_att_m_fwd(
413
435
  max_kv_splits,
414
436
  sm_scale,
415
437
  logit_cap,
438
+ xai_temperature_len=-1,
416
439
  ):
417
440
  BLOCK = 32
418
441
  Lk = k_buffer.shape[-1]
@@ -433,7 +456,7 @@ def _decode_grouped_att_m_fwd(
433
456
  BLOCK_DPE = 0
434
457
  BLOCK_DV = triton.next_power_of_2(Lv)
435
458
 
436
- batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
459
+ batch, head_num = q.shape[0], q.shape[1]
437
460
  kv_group_num = q.shape[1] // k_buffer.shape[1]
438
461
 
439
462
  BLOCK_H = 16
@@ -480,6 +503,7 @@ def _decode_grouped_att_m_fwd(
480
503
  BLOCK_H=BLOCK_H,
481
504
  MIN_BLOCK_KV=_MIN_BLOCK_KV,
482
505
  logit_cap=logit_cap,
506
+ xai_temperature_len=xai_temperature_len,
483
507
  num_warps=4,
484
508
  num_stages=num_stages,
485
509
  Lk=Lk,
@@ -620,6 +644,7 @@ def decode_attention_fwd_normal(
620
644
  sm_scale,
621
645
  logit_cap=0.0,
622
646
  sinks=None,
647
+ xai_temperature_len=-1,
623
648
  ):
624
649
  _decode_att_m_fwd(
625
650
  q,
@@ -633,6 +658,7 @@ def decode_attention_fwd_normal(
633
658
  max_kv_splits,
634
659
  sm_scale,
635
660
  logit_cap,
661
+ xai_temperature_len,
636
662
  )
637
663
  _decode_softmax_reducev_fwd(
638
664
  attn_logits,
@@ -661,6 +687,7 @@ def decode_attention_fwd_grouped(
661
687
  sm_scale,
662
688
  logit_cap=0.0,
663
689
  sinks=None,
690
+ xai_temperature_len=-1,
664
691
  ):
665
692
  _decode_grouped_att_m_fwd(
666
693
  q,
@@ -674,6 +701,7 @@ def decode_attention_fwd_grouped(
674
701
  max_kv_splits,
675
702
  sm_scale,
676
703
  logit_cap,
704
+ xai_temperature_len,
677
705
  )
678
706
  _decode_softmax_reducev_fwd(
679
707
  attn_logits,
@@ -702,6 +730,7 @@ def decode_attention_fwd(
702
730
  sm_scale,
703
731
  logit_cap=0.0,
704
732
  sinks=None,
733
+ xai_temperature_len=-1,
705
734
  ):
706
735
  assert max_kv_splits == attn_logits.shape[2]
707
736
  assert q.shape[0] <= kv_indptr.shape[0] - 1
@@ -725,6 +754,7 @@ def decode_attention_fwd(
725
754
  sm_scale,
726
755
  logit_cap=logit_cap,
727
756
  sinks=sinks,
757
+ xai_temperature_len=xai_temperature_len,
728
758
  )
729
759
  else:
730
760
  # GQA/MQA/MLA
@@ -742,4 +772,5 @@ def decode_attention_fwd(
742
772
  sm_scale,
743
773
  logit_cap=logit_cap,
744
774
  sinks=sinks,
775
+ xai_temperature_len=xai_temperature_len,
745
776
  )
@@ -52,6 +52,7 @@ def _fwd_kernel(
52
52
  mask_ptr,
53
53
  mask_indptr,
54
54
  sink_ptr,
55
+ window_kv_offset_ptr,
55
56
  sm_scale,
56
57
  kv_group_num,
57
58
  stride_qbs,
@@ -68,6 +69,7 @@ def _fwd_kernel(
68
69
  stride_buf_vh,
69
70
  SLIDING_WINDOW_SIZE: tl.constexpr,
70
71
  logit_cap: tl.constexpr,
72
+ xai_temperature_len: tl.constexpr,
71
73
  Lq: tl.constexpr,
72
74
  Lv: tl.constexpr,
73
75
  BLOCK_DMODEL: tl.constexpr,
@@ -95,6 +97,11 @@ def _fwd_kernel(
95
97
  if USE_CUSTOM_MASK:
96
98
  cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)
97
99
 
100
+ # For SWA, we should only load the mask in the sliding window
101
+ window_kv_offset = 0
102
+ if USE_CUSTOM_MASK and SLIDING_WINDOW_SIZE > 0:
103
+ window_kv_offset = tl.load(window_kv_offset_ptr + cur_seq)
104
+
98
105
  offs_d = tl.arange(0, BLOCK_DMODEL)
99
106
  offs_dv = tl.arange(0, BLOCK_DV)
100
107
  offs_m = tl.arange(0, BLOCK_M)
@@ -103,6 +110,15 @@ def _fwd_kernel(
103
110
  mask_d = offs_d < Lq
104
111
  mask_dv = offs_dv < Lv
105
112
 
113
+ if xai_temperature_len > 0:
114
+ offs_qidx = cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m
115
+ xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
116
+ xai_temperature_reg = tl.where(
117
+ offs_qidx > xai_temperature_len,
118
+ tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale,
119
+ 1.0,
120
+ )
121
+
106
122
  offs_q = (
107
123
  (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
108
124
  * stride_qbs
@@ -139,7 +155,9 @@ def _fwd_kernel(
139
155
  custom_mask = tl.load(
140
156
  mask_ptr
141
157
  + cur_seq_mask_start_idx
142
- + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
158
+ + (cur_block_m * BLOCK_M + offs_m[:, None])
159
+ * (cur_seq_len + window_kv_offset)
160
+ + window_kv_offset
143
161
  + start_n
144
162
  + offs_n[None, :],
145
163
  mask=(mask_m[:, None] & mask_n[None, :]),
@@ -195,6 +213,9 @@ def _fwd_kernel(
195
213
  if logit_cap > 0:
196
214
  qk = logit_cap * tanh(qk / logit_cap)
197
215
 
216
+ if xai_temperature_len > 0:
217
+ qk *= xai_temperature_reg[:, None]
218
+
198
219
  qk = tl.where(final_mask, qk, float("-inf"))
199
220
 
200
221
  row_max = tl.max(qk, 1)
@@ -236,7 +257,9 @@ def _fwd_kernel(
236
257
  custom_mask = tl.load(
237
258
  mask_ptr
238
259
  + cur_seq_mask_start_idx
239
- + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
260
+ + (cur_block_m * BLOCK_M + offs_m[:, None])
261
+ * (cur_seq_len + window_kv_offset)
262
+ + window_kv_offset
240
263
  + cur_seq_len_prefix
241
264
  + start_n
242
265
  + offs_n[None, :],
@@ -296,6 +319,9 @@ def _fwd_kernel(
296
319
  if logit_cap > 0:
297
320
  qk = logit_cap * tanh(qk / logit_cap)
298
321
 
322
+ if xai_temperature_len > 0:
323
+ qk *= xai_temperature_reg[:, None]
324
+
299
325
  qk = tl.where(final_mask, qk, float("-inf"))
300
326
 
301
327
  row_max = tl.max(qk, 1)
@@ -362,6 +388,8 @@ def extend_attention_fwd(
362
388
  skip_prefix_custom_mask=True,
363
389
  sliding_window_size=-1,
364
390
  sinks=None,
391
+ window_kv_offsets=None,
392
+ xai_temperature_len=-1,
365
393
  ):
366
394
  """
367
395
  q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -449,6 +477,7 @@ def extend_attention_fwd(
449
477
  custom_mask,
450
478
  mask_indptr,
451
479
  sinks,
480
+ window_kv_offsets,
452
481
  sm_scale,
453
482
  kv_group_num,
454
483
  q_extend.stride(0),
@@ -465,6 +494,7 @@ def extend_attention_fwd(
465
494
  v_buffer.stride(1),
466
495
  SLIDING_WINDOW_SIZE=sliding_window_size,
467
496
  logit_cap=logit_cap,
497
+ xai_temperature_len=xai_temperature_len,
468
498
  BLOCK_DMODEL=BLOCK_DMODEL,
469
499
  BLOCK_DPE=BLOCK_DPE,
470
500
  BLOCK_DV=BLOCK_DV,