sglang 0.5.0rc2__py3-none-any.whl → 0.5.1.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -6
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +24 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +27 -2
- sglang/srt/entrypoints/http_server.py +12 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -2
- sglang/srt/entrypoints/openai/serving_chat.py +22 -6
- sglang/srt/entrypoints/openai/serving_completions.py +9 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +11 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
- sglang/srt/layers/attention/triton_backend.py +85 -46
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
- sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +51 -3
- sglang/srt/layers/dp_attention.py +23 -4
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +5 -1
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/quantization/__init__.py +13 -14
- sglang/srt/layers/quantization/awq.py +7 -7
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -0
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +5 -4
- sglang/srt/layers/quantization/marlin_utils.py +11 -3
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +165 -68
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +206 -37
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +25 -0
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +76 -18
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +9 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +4 -9
- sglang/srt/managers/scheduler.py +25 -16
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +60 -21
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +7 -5
- sglang/srt/mem_cache/allocator_ascend.py +0 -11
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +25 -12
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +43 -32
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +3 -1
- sglang/srt/models/deepseek_v2.py +224 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +25 -63
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +34 -74
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +375 -51
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama4.py +0 -2
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +3 -18
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +7 -1
- sglang/srt/models/qwen3_moe.py +9 -38
- sglang/srt/models/step3_vl.py +2 -1
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +6 -1
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/server_args.py +237 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +16 -11
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/METADATA +7 -7
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/RECORD +179 -161
- sglang/srt/layers/quantization/fp4.py +0 -557
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/layers/attention/triton_backend.py
+++ b/sglang/srt/layers/attention/triton_backend.py
@@ -20,6 +20,14 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput


+def logit_capping_mod(logit_capping_method, logit_cap):
+    # positive logit_cap -> tanh cap
+    if logit_capping_method == "tanh":
+        return logit_cap
+    else:
+        raise ValueError()
+
+
 @dataclass
 class ForwardMetadata:
     attn_logits: torch.Tensor
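The helper added above only selects the cap value; the capping itself happens inside the Triton kernels further down in this diff, which apply `qk = logit_cap * tanh(qk / logit_cap)` whenever `logit_cap > 0`. A minimal sketch of that squashing, for illustration only (the `tanh_soft_cap` name is made up and not part of the package):

```python
# Sketch of a "tanh" soft cap: smoothly bounds attention scores into
# (-logit_cap, logit_cap), mirroring `qk = logit_cap * tanh(qk / logit_cap)`
# guarded by `if logit_cap > 0` in the kernels below.
import torch

def tanh_soft_cap(scores: torch.Tensor, logit_cap: float) -> torch.Tensor:
    if logit_cap <= 0:
        return scores  # capping disabled
    return logit_cap * torch.tanh(scores / logit_cap)

scores = torch.tensor([-120.0, -5.0, 0.0, 5.0, 120.0])
print(tanh_soft_cap(scores, 50.0))  # large |scores| saturate near +/-50
```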
@@ -35,6 +43,7 @@ class ForwardMetadata:
     window_kv_indptr: torch.Tensor
     window_kv_indices: torch.Tensor
     window_num_kv_splits: torch.Tensor
+    window_kv_offsets: torch.Tensor


 class TritonAttnBackend(AttentionBackend):
@@ -163,6 +172,7 @@ class TritonAttnBackend(AttentionBackend):
         window_kv_indptr = self.window_kv_indptr
         window_kv_indices = None
         window_num_kv_splits = None
+        window_kv_offsets = None
         spec_info = forward_batch.spec_info

         if forward_batch.forward_mode.is_decode_or_idle():
@@ -170,7 +180,7 @@ class TritonAttnBackend(AttentionBackend):
                 kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                 kv_indptr = kv_indptr[: bs + 1]
                 kv_indices = torch.empty(
-                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
+                    forward_batch.seq_lens_sum, dtype=torch.int64, device=self.device
                 )
                 create_flashinfer_kv_indices_triton[(bs,)](
                     self.req_to_token,
@@ -186,7 +196,7 @@ class TritonAttnBackend(AttentionBackend):
                 self.sliding_window_size is not None
                 and self.sliding_window_size > 0
             ):
-                window_kv_indptr, window_kv_indices, window_kv_lens = (
+                window_kv_indptr, window_kv_indices, window_kv_lens, _ = (
                     update_sliding_window_buffer(
                         self.window_kv_indptr,
                         self.req_to_token,
@@ -236,7 +246,7 @@ class TritonAttnBackend(AttentionBackend):
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-                kv_indptr[-1], dtype=torch.int32, device=self.device
+                kv_indptr[-1], dtype=torch.int64, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
@@ -249,17 +259,21 @@ class TritonAttnBackend(AttentionBackend):
             )

             if self.sliding_window_size is not None and self.sliding_window_size > 0:
-                window_kv_indptr, window_kv_indices, window_kv_lens = (
-                    update_sliding_window_buffer(
-                        self.window_kv_indptr,
-                        self.req_to_token,
-                        self.sliding_window_size,
-                        forward_batch.seq_lens,
-                        forward_batch.req_pool_indices,
-                        bs,
-                        self.device,
-                        self.token_to_kv_pool_allocator,
-                    )
+                # window_kv_offsets is used to calculate the start position in custom mask
+                (
+                    window_kv_indptr,
+                    window_kv_indices,
+                    window_kv_lens,
+                    window_kv_offsets,
+                ) = update_sliding_window_buffer(
+                    self.window_kv_indptr,
+                    self.req_to_token,
+                    self.sliding_window_size,
+                    forward_batch.seq_lens,
+                    forward_batch.req_pool_indices,
+                    bs,
+                    self.device,
+                    self.token_to_kv_pool_allocator,
                 )

             custom_mask = spec_info.custom_mask
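For context on the new `window_kv_offsets` value: a later hunk in this file computes `window_kv_start_idx = seq_lens - window_kv_lens` and returns it from `update_sliding_window_buffer`, and the comment above says it marks the start position of the sliding window inside the custom mask. A minimal sketch of that per-sequence offset, assuming `window_kv_lens` is simply the sequence length clamped to the sliding window (the `window_offsets` helper below is illustrative, not part of the package):

```python
# Sketch: how many leading tokens of each sequence fall outside the sliding
# window; this is the offset used to index into the full custom mask.
import torch

def window_offsets(seq_lens: torch.Tensor, sliding_window_size: int) -> torch.Tensor:
    window_kv_lens = torch.clamp(seq_lens, max=sliding_window_size)  # assumption
    return seq_lens - window_kv_lens  # 0 when the whole sequence fits in the window

seq_lens = torch.tensor([8, 40, 200])
print(window_offsets(seq_lens, 64))  # tensor([  0,   0, 136])
```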
@@ -283,6 +297,7 @@ class TritonAttnBackend(AttentionBackend):
                     self.req_to_token,
                 )
             )
+            kv_indices = kv_indices.to(torch.int64)
             mask_indptr = None
             # TODO(FIXME): This will trigger an invalid Eagle tree when using
             # `max(spec_info.accept_length_cpu)`.
@@ -298,7 +313,7 @@ class TritonAttnBackend(AttentionBackend):
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
                 forward_batch.extend_prefix_lens.sum().item(),
-                dtype=torch.int32,
+                dtype=torch.int64,
                 device=self.device,
             )
             create_flashinfer_kv_indices_triton[(bs,)](
@@ -312,15 +327,17 @@ class TritonAttnBackend(AttentionBackend):
             )
             # Sliding window
             if self.sliding_window_size is not None and self.sliding_window_size > 0:
-                window_kv_indptr, window_kv_indices, _ = update_sliding_window_buffer(
-                    self.window_kv_indptr,
-                    self.req_to_token,
-                    self.sliding_window_size,
-                    forward_batch.extend_prefix_lens,
-                    forward_batch.req_pool_indices,
-                    bs,
-                    self.device,
-                    self.token_to_kv_pool_allocator,
+                window_kv_indptr, window_kv_indices, _, _ = (
+                    update_sliding_window_buffer(
+                        self.window_kv_indptr,
+                        self.req_to_token,
+                        self.sliding_window_size,
+                        forward_batch.extend_prefix_lens,
+                        forward_batch.req_pool_indices,
+                        bs,
+                        self.device,
+                        self.token_to_kv_pool_allocator,
+                    )
                 )

             qo_indptr = self.qo_indptr
@@ -346,6 +363,7 @@ class TritonAttnBackend(AttentionBackend):
             window_kv_indptr,
             window_kv_indices,
             window_num_kv_splits,
+            window_kv_offsets,
         )

     def init_cuda_graph_state(
@@ -370,7 +388,7 @@ class TritonAttnBackend(AttentionBackend):
         if kv_indices_buf is None:
             self.cuda_graph_kv_indices = torch.zeros(
                 (max_num_tokens * self.max_context_len),
-                dtype=torch.int32,
+                dtype=torch.int64,
                 device=self.device,
             )
         else:
@@ -387,7 +405,7 @@ class TritonAttnBackend(AttentionBackend):
             if kv_indices_buf is None:
                 self.cuda_graph_window_kv_indices = torch.zeros(
                     (max_num_tokens * self.sliding_window_size),
-                    dtype=torch.int32,
+                    dtype=torch.int64,
                     device=self.device,
                 )
             else:
@@ -400,6 +418,12 @@ class TritonAttnBackend(AttentionBackend):
                 device=self.device,
             )

+            self.cuda_graph_window_kv_offsets = torch.zeros(
+                (max_bs,),
+                dtype=torch.int32,
+                device=self.device,
+            )
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
@@ -414,6 +438,7 @@ class TritonAttnBackend(AttentionBackend):
         window_kv_indptr = self.window_kv_indptr
         window_kv_indices = None
         window_num_kv_splits = None
+        window_kv_offsets = None

         if forward_mode.is_decode_or_idle():
             if spec_info is None:
@@ -436,7 +461,7 @@ class TritonAttnBackend(AttentionBackend):
             ):
                 window_kv_indices = self.cuda_graph_window_kv_indices
                 window_num_kv_splits = self.cuda_graph_window_num_kv_splits
-                window_kv_indptr, window_kv_indices, _ = (
+                window_kv_indptr, window_kv_indices, _, _ = (
                     update_sliding_window_buffer_cuda_graph(
                         self.window_kv_indptr,
                         window_kv_indices,
@@ -483,13 +508,14 @@ class TritonAttnBackend(AttentionBackend):
             if self.sliding_window_size is not None and self.sliding_window_size > 0:
                 window_kv_indices = self.cuda_graph_window_kv_indices
                 window_num_kv_splits = self.cuda_graph_window_num_kv_splits
-                window_kv_indptr, window_kv_indices, _ = (
+                window_kv_offsets = self.cuda_graph_window_kv_offsets
+                window_kv_indptr, window_kv_indices, _, window_kv_offsets[:bs] = (
                     update_sliding_window_buffer_cuda_graph(
                         self.window_kv_indptr,
                         window_kv_indices,
                         self.req_to_token,
                         self.sliding_window_size,
-                        seq_lens,
+                        seq_lens[:bs],
                         req_pool_indices,
                         bs,
                         self.token_to_kv_pool_allocator,
@@ -551,6 +577,7 @@ class TritonAttnBackend(AttentionBackend):
             window_kv_indptr,
             window_kv_indices,
             window_num_kv_splits,
+            window_kv_offsets,
         )

     def init_forward_metadata_replay_cuda_graph(
@@ -589,7 +616,7 @@ class TritonAttnBackend(AttentionBackend):
             ):
                 window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                 window_kv_indices = self.cuda_graph_window_kv_indices
-                _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
+                _, _, window_kv_lens, _ = update_sliding_window_buffer_cuda_graph(
                     self.window_kv_indptr,
                     window_kv_indices,
                     self.req_to_token,
@@ -635,15 +662,18 @@ class TritonAttnBackend(AttentionBackend):
             if self.sliding_window_size is not None and self.sliding_window_size > 0:
                 window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                 window_kv_indices = self.cuda_graph_window_kv_indices
-                _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
-                    self.window_kv_indptr,
-                    window_kv_indices,
-                    self.req_to_token,
-                    self.sliding_window_size,
-                    seq_lens,
-                    req_pool_indices,
-                    bs,
-                    self.token_to_kv_pool_allocator,
+                window_kv_offsets = self.cuda_graph_window_kv_offsets
+                _, _, window_kv_lens, window_kv_offsets[:bs] = (
+                    update_sliding_window_buffer_cuda_graph(
+                        self.window_kv_indptr,
+                        window_kv_indices,
+                        self.req_to_token,
+                        self.sliding_window_size,
+                        seq_lens[:bs],
+                        req_pool_indices,
+                        bs,
+                        self.token_to_kv_pool_allocator,
+                    )
                 )
             custom_mask = self.cuda_graph_custom_mask
             custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
@@ -696,6 +726,8 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )

+        logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
+
        causal = True
        if layer.attn_type == AttentionType.ENCODER_ONLY:
            causal = False
@@ -706,10 +738,12 @@ class TritonAttnBackend(AttentionBackend):
             )  # Needed for sliding window mask
             kv_indptr = self.forward_metadata.window_kv_indptr
             kv_indices = self.forward_metadata.window_kv_indices
+            window_kv_offsets = self.forward_metadata.window_kv_offsets
         else:
             sliding_window_size = -1
             kv_indptr = self.forward_metadata.kv_indptr
             kv_indices = self.forward_metadata.kv_indices
+            window_kv_offsets = None

         self.extend_attention_fwd(
@@ -726,9 +760,11 @@ class TritonAttnBackend(AttentionBackend):
             self.forward_metadata.mask_indptr,
             self.forward_metadata.max_extend_len,
             layer.scaling,
-            logit_cap=layer.logit_cap,
+            logit_cap=logits_soft_cap,
             sliding_window_size=sliding_window_size,
             sinks=sinks,
+            window_kv_offsets=window_kv_offsets,
+            xai_temperature_len=layer.xai_temperature_len,
         )
         return o

@@ -752,6 +788,8 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)

+        logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
+
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
                 layer, forward_batch.out_cache_loc, k, v
@@ -776,8 +814,9 @@ class TritonAttnBackend(AttentionBackend):
             self.forward_metadata.num_kv_splits,
             self.max_kv_splits,
             layer.scaling,
-            logit_cap=layer.logit_cap,
+            logit_cap=logits_soft_cap,
             sinks=sinks,
+            xai_temperature_len=layer.xai_temperature_len,
         )
         return o

@@ -864,7 +903,7 @@ class TritonMultiStepDraftBackend:
                 self.speculative_num_steps,
                 forward_batch.batch_size * self.topk * self.max_context_len,
             ),
-            dtype=torch.int32,
+            dtype=torch.int64,
             device=self.device,
         )

@@ -882,7 +921,7 @@ class TritonMultiStepDraftBackend:
    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        self.cuda_graph_kv_indices = torch.zeros(
            (self.speculative_num_steps, max_num_tokens * self.max_context_len),
-           dtype=torch.int32,
+           dtype=torch.int64,
            device=self.device,
        )
        for i in range(self.speculative_num_steps):
@@ -991,7 +1030,7 @@ def update_sliding_window_buffer(
     window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
     window_kv_indptr = window_kv_indptr[: bs + 1]
     window_kv_indices = torch.empty(
-        window_kv_indptr[-1], dtype=torch.int32, device=device
+        window_kv_indptr[-1], dtype=torch.int64, device=device
     )
     window_kv_start_idx = seq_lens - window_kv_lens
     create_flashinfer_kv_indices_triton[(bs,)](
@@ -1011,7 +1050,7 @@ def update_sliding_window_buffer(
                 window_kv_indices[:kv_last_index]
             )
         )
-    return window_kv_indptr, window_kv_indices, window_kv_lens
+    return window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_start_idx


 def update_sliding_window_buffer_cuda_graph(
@@ -1048,4 +1087,4 @@ def update_sliding_window_buffer_cuda_graph(
                 window_kv_indices[:kv_last_index]
             )
         )
-    return window_kv_indptr, window_kv_indices, window_kv_lens
+    return window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_start_idx
--- a/sglang/srt/layers/attention/triton_ops/decode_attention.py
+++ b/sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -69,6 +69,7 @@ def _fwd_kernel_stage1(
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
     Lv: tl.constexpr,
+    xai_temperature_len: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -85,6 +86,12 @@
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
     kv_splits = tl.load(num_kv_splits + cur_batch)

+    if xai_temperature_len > 0:
+        offs_qidx = cur_batch_seq_len - 1
+        xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
+        _qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale
+        xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0)
+
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d

     kv_len_per_split = (
@@ -122,6 +129,9 @@
             if logit_cap > 0:
                 qk = logit_cap * tanh(qk / logit_cap)

+            if xai_temperature_len > 0:
+                qk *= xai_temperature_reg
+
             qk = tl.where(offs_n < split_kv_end, qk, float("-inf"))

             offs_buf_v = (
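All of the new `xai_temperature_len` branches in this diff compute the same position-dependent multiplier and apply it to the attention scores (`qk *= xai_temperature_reg`). A minimal sketch of that formula as the kernels express it, for illustration only (the `xai_temperature_reg` helper name below is made up):

```python
# Sketch: for a query at position p (in decode, p = seq_len - 1), the score
# multiplier is log2(p) / log2(xai_temperature_len) once p exceeds
# xai_temperature_len, and 1.0 otherwise.
import math

def xai_temperature_reg(p: int, xai_temperature_len: int) -> float:
    if xai_temperature_len <= 0 or p <= xai_temperature_len:
        return 1.0
    return math.log2(p) / math.log2(xai_temperature_len)

for p in [100, 1024, 4096, 65536]:
    print(p, round(xai_temperature_reg(p, xai_temperature_len=1024), 4))
# 100 -> 1.0, 1024 -> 1.0, 4096 -> 1.2, 65536 -> 1.6
```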
@@ -181,6 +191,7 @@ def _decode_att_m_fwd(
     max_kv_splits,
     sm_scale,
     logit_cap,
+    xai_temperature_len=-1,
 ):
     BLOCK = 64
     # [TODO] work around SGPR limit on MI3xx
@@ -190,7 +201,7 @@ def _decode_att_m_fwd(
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]

-    batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
+    batch, head_num = q.shape[0], q.shape[1]

     grid = (batch, head_num, MAX_KV_SPLITS)
     kv_group_num = q.shape[1] // k_buffer.shape[1]
@@ -230,6 +241,7 @@
         BLOCK_N=BLOCK,
         MIN_BLOCK_KV=_MIN_BLOCK_KV,
         logit_cap=logit_cap,
+        xai_temperature_len=xai_temperature_len,
         num_warps=num_warps,
         num_stages=2,
         Lk=Lk,
@@ -266,6 +278,7 @@ def _fwd_grouped_kernel_stage1(
     BLOCK_H: tl.constexpr,
     MIN_BLOCK_KV: tl.constexpr,
     logit_cap: tl.constexpr,
+    xai_temperature_len: tl.constexpr,
     Lk: tl.constexpr,
     Lv: tl.constexpr,
 ):
@@ -291,6 +304,12 @@
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
     kv_splits = tl.load(num_kv_splits + cur_batch)

+    if xai_temperature_len > 0:
+        offs_qidx = cur_batch_seq_len - 1
+        xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
+        _qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale
+        xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0)
+
     offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]

     if BLOCK_DPE > 0:
@@ -351,6 +370,9 @@
             if logit_cap > 0:
                 qk = logit_cap * tanh(qk / logit_cap)

+            if xai_temperature_len > 0:
+                qk *= xai_temperature_reg[:, None]
+
             qk = tl.where(
                 mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
             )
@@ -413,6 +435,7 @@ def _decode_grouped_att_m_fwd(
     max_kv_splits,
     sm_scale,
     logit_cap,
+    xai_temperature_len=-1,
 ):
     BLOCK = 32
     Lk = k_buffer.shape[-1]
@@ -433,7 +456,7 @@
     BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)

-    batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
+    batch, head_num = q.shape[0], q.shape[1]
     kv_group_num = q.shape[1] // k_buffer.shape[1]

     BLOCK_H = 16
@@ -480,6 +503,7 @@
         BLOCK_H=BLOCK_H,
         MIN_BLOCK_KV=_MIN_BLOCK_KV,
         logit_cap=logit_cap,
+        xai_temperature_len=xai_temperature_len,
         num_warps=4,
         num_stages=num_stages,
         Lk=Lk,
@@ -620,6 +644,7 @@ def decode_attention_fwd_normal(
     sm_scale,
     logit_cap=0.0,
     sinks=None,
+    xai_temperature_len=-1,
 ):
     _decode_att_m_fwd(
         q,
@@ -633,6 +658,7 @@ def decode_attention_fwd_normal(
         max_kv_splits,
         sm_scale,
         logit_cap,
+        xai_temperature_len,
     )
     _decode_softmax_reducev_fwd(
         attn_logits,
@@ -661,6 +687,7 @@ def decode_attention_fwd_grouped(
     sm_scale,
     logit_cap=0.0,
     sinks=None,
+    xai_temperature_len=-1,
 ):
     _decode_grouped_att_m_fwd(
         q,
@@ -674,6 +701,7 @@ def decode_attention_fwd_grouped(
         max_kv_splits,
         sm_scale,
         logit_cap,
+        xai_temperature_len,
     )
     _decode_softmax_reducev_fwd(
         attn_logits,
@@ -702,6 +730,7 @@ def decode_attention_fwd(
     sm_scale,
     logit_cap=0.0,
     sinks=None,
+    xai_temperature_len=-1,
 ):
     assert max_kv_splits == attn_logits.shape[2]
     assert q.shape[0] <= kv_indptr.shape[0] - 1
@@ -725,6 +754,7 @@ def decode_attention_fwd(
             sm_scale,
             logit_cap=logit_cap,
             sinks=sinks,
+            xai_temperature_len=xai_temperature_len,
         )
     else:
         # GQA/MQA/MLA
@@ -742,4 +772,5 @@ def decode_attention_fwd(
             sm_scale,
             logit_cap=logit_cap,
             sinks=sinks,
+            xai_temperature_len=xai_temperature_len,
         )
--- a/sglang/srt/layers/attention/triton_ops/extend_attention.py
+++ b/sglang/srt/layers/attention/triton_ops/extend_attention.py
@@ -52,6 +52,7 @@ def _fwd_kernel(
     mask_ptr,
     mask_indptr,
     sink_ptr,
+    window_kv_offset_ptr,
     sm_scale,
     kv_group_num,
     stride_qbs,
@@ -68,6 +69,7 @@ def _fwd_kernel(
     stride_buf_vh,
     SLIDING_WINDOW_SIZE: tl.constexpr,
     logit_cap: tl.constexpr,
+    xai_temperature_len: tl.constexpr,
     Lq: tl.constexpr,
     Lv: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
@@ -95,6 +97,11 @@
     if USE_CUSTOM_MASK:
         cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)

+    # For SWA, we should only load the mask in the sliding window
+    window_kv_offset = 0
+    if USE_CUSTOM_MASK and SLIDING_WINDOW_SIZE > 0:
+        window_kv_offset = tl.load(window_kv_offset_ptr + cur_seq)
+
     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_dv = tl.arange(0, BLOCK_DV)
     offs_m = tl.arange(0, BLOCK_M)
@@ -103,6 +110,15 @@
     mask_d = offs_d < Lq
     mask_dv = offs_dv < Lv

+    if xai_temperature_len > 0:
+        offs_qidx = cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m
+        xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
+        xai_temperature_reg = tl.where(
+            offs_qidx > xai_temperature_len,
+            tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale,
+            1.0,
+        )
+
     offs_q = (
         (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_qbs
@@ -139,7 +155,9 @@
             custom_mask = tl.load(
                 mask_ptr
                 + cur_seq_mask_start_idx
-                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
+                + (cur_block_m * BLOCK_M + offs_m[:, None])
+                * (cur_seq_len + window_kv_offset)
+                + window_kv_offset
                 + start_n
                 + offs_n[None, :],
                 mask=(mask_m[:, None] & mask_n[None, :]),
@@ -195,6 +213,9 @@
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)

+        if xai_temperature_len > 0:
+            qk *= xai_temperature_reg[:, None]
+
         qk = tl.where(final_mask, qk, float("-inf"))

         row_max = tl.max(qk, 1)
@@ -236,7 +257,9 @@
             custom_mask = tl.load(
                 mask_ptr
                 + cur_seq_mask_start_idx
-                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
+                + (cur_block_m * BLOCK_M + offs_m[:, None])
+                * (cur_seq_len + window_kv_offset)
+                + window_kv_offset
                 + cur_seq_len_prefix
                 + start_n
                 + offs_n[None, :],
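The reindexing above is the core of the sliding-window custom-mask change: the row stride becomes `cur_seq_len + window_kv_offset` and every access is shifted right by `window_kv_offset`. A minimal sketch of that flat-index arithmetic, assuming `cur_seq_len` here is the in-window KV length (all names below are illustrative, not package code):

```python
# Sketch: read mask[row, window_kv_offset + col] from a flattened mask whose
# rows cover the full sequence, while the kernel only iterates over the
# in-window columns:
#   flat_idx = row * (kv_len + window_kv_offset) + window_kv_offset + col
import numpy as np

full_len, window_kv_offset = 10, 4      # 10 mask columns, window starts at column 4
kv_len = full_len - window_kv_offset    # 6 in-window columns the kernel visits
mask = np.arange(3 * full_len)          # stand-in for a flattened 3-row mask

row, col = 2, 1                         # second in-window column of row 2
flat_idx = row * (kv_len + window_kv_offset) + window_kv_offset + col
assert mask[flat_idx] == mask.reshape(3, full_len)[row, window_kv_offset + col]
```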
@@ -296,6 +319,9 @@
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)

+        if xai_temperature_len > 0:
+            qk *= xai_temperature_reg[:, None]
+
         qk = tl.where(final_mask, qk, float("-inf"))

         row_max = tl.max(qk, 1)
@@ -362,6 +388,8 @@ def extend_attention_fwd(
     skip_prefix_custom_mask=True,
     sliding_window_size=-1,
     sinks=None,
+    window_kv_offsets=None,
+    xai_temperature_len=-1,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -449,6 +477,7 @@
         custom_mask,
         mask_indptr,
         sinks,
+        window_kv_offsets,
         sm_scale,
         kv_group_num,
         q_extend.stride(0),
@@ -465,6 +494,7 @@
         v_buffer.stride(1),
         SLIDING_WINDOW_SIZE=sliding_window_size,
         logit_cap=logit_cap,
+        xai_temperature_len=xai_temperature_len,
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_DPE=BLOCK_DPE,
         BLOCK_DV=BLOCK_DV,