sglang 0.5.1.post1__py3-none-any.whl → 0.5.1.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +79 -53
- sglang/bench_serving.py +186 -14
- sglang/profiler.py +0 -1
- sglang/srt/conversation.py +38 -5
- sglang/srt/disaggregation/decode.py +4 -0
- sglang/srt/disaggregation/prefill.py +4 -0
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/entrypoints/openai/protocol.py +27 -24
- sglang/srt/entrypoints/openai/serving_chat.py +50 -9
- sglang/srt/entrypoints/openai/serving_completions.py +15 -0
- sglang/srt/entrypoints/tool.py +7 -7
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/harmony_parser.py +588 -0
- sglang/srt/hf_transformers_utils.py +16 -7
- sglang/srt/layers/attention/ascend_backend.py +218 -111
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +76 -91
- sglang/srt/layers/attention/utils.py +15 -94
- sglang/srt/layers/communicator.py +1 -2
- sglang/srt/layers/moe/cutlass_moe.py +0 -15
- sglang/srt/layers/moe/ep_moe/layer.py +1 -7
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -7
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
- sglang/srt/layers/quantization/fp8.py +2 -1
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/mxfp4.py +16 -23
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/lora_manager.py +29 -12
- sglang/srt/managers/cache_controller.py +223 -156
- sglang/srt/managers/detokenizer_manager.py +5 -0
- sglang/srt/managers/io_struct.py +30 -0
- sglang/srt/managers/scheduler.py +58 -7
- sglang/srt/managers/scheduler_metrics_mixin.py +15 -0
- sglang/srt/managers/tokenizer_manager.py +36 -3
- sglang/srt/mem_cache/hicache_storage.py +31 -20
- sglang/srt/mem_cache/hiradix_cache.py +12 -3
- sglang/srt/mem_cache/memory_pool.py +73 -14
- sglang/srt/mem_cache/memory_pool_host.py +3 -2
- sglang/srt/mem_cache/radix_cache.py +1 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +5 -13
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +85 -81
- sglang/srt/metrics/collector.py +5 -5
- sglang/srt/model_executor/cuda_graph_runner.py +2 -2
- sglang/srt/model_executor/model_runner.py +1 -1
- sglang/srt/models/deepseek_v2.py +12 -3
- sglang/srt/models/gpt_oss.py +2 -1
- sglang/srt/models/qwen2_5_vl.py +1 -0
- sglang/srt/offloader.py +115 -0
- sglang/srt/reasoning_parser.py +56 -300
- sglang/srt/server_args.py +10 -5
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +59 -12
- sglang/test/test_cutlass_moe.py +33 -28
- sglang/version.py +1 -1
- {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/METADATA +6 -5
- {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/RECORD +69 -65
- {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/top_level.txt +0 -0
@@ -5,6 +5,8 @@ from typing import TYPE_CHECKING, Optional, Union
|
|
5
5
|
|
6
6
|
import numpy as np
|
7
7
|
import torch
|
8
|
+
import triton
|
9
|
+
import triton.language as tl
|
8
10
|
|
9
11
|
from sglang.srt.configs.model_config import AttentionArch
|
10
12
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
@@ -64,6 +66,9 @@ class FlashAttentionMetadata:
|
|
64
66
|
|
65
67
|
local_attn_metadata: Optional[LocalAttentionMetadata] = None
|
66
68
|
|
69
|
+
# For sliding window attention topk>1 spec decoding
|
70
|
+
swa_spec_metadata: Optional[FlashAttentionMetadata] = None
|
71
|
+
|
67
72
|
|
68
73
|
# Copied from:
|
69
74
|
# https://github.com/houseroad/vllm/blob/4e45bfcaf928bdb9bd952b4ac922a3c205589ae8/vllm/v1/attention/backends/flash_attn.py
|
@@ -340,6 +345,13 @@ class FlashAttentionBackend(AttentionBackend):
|
|
340
345
|
else None
|
341
346
|
)
|
342
347
|
|
348
|
+
# For each layer, the sliding_window_size can be different. This is only used for preparing SWA metadata.
|
349
|
+
# We use `layer.sliding_window_size` to decide whether to use SWA for each layer.
|
350
|
+
self.sliding_window_size = model_runner.sliding_window_size
|
351
|
+
self.has_swa = (
|
352
|
+
self.sliding_window_size is not None and self.sliding_window_size > -1
|
353
|
+
)
|
354
|
+
|
343
355
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
344
356
|
"""Initialize forward metadata hence all layers in the forward pass can reuse it."""
|
345
357
|
metadata = FlashAttentionMetadata()
|
@@ -556,6 +568,12 @@ class FlashAttentionBackend(AttentionBackend):
|
|
556
568
|
(1, 0),
|
557
569
|
)
|
558
570
|
self.forward_metadata_spec_decode_expand = metadata_expand
|
571
|
+
|
572
|
+
if self.has_swa:
|
573
|
+
self._init_sliding_window_attn_spec_metadata(
|
574
|
+
metadata, metadata_expand
|
575
|
+
)
|
576
|
+
|
559
577
|
elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
|
560
578
|
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
561
579
|
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
|
@@ -657,11 +675,10 @@ class FlashAttentionBackend(AttentionBackend):
|
|
657
675
|
# Calculate window size (can be moved to metadata if layer properties don't change)
|
658
676
|
# we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
|
659
677
|
# here is two side inclusive
|
660
|
-
|
661
|
-
|
662
|
-
if layer.sliding_window_size is not None and layer.sliding_window_size > -1
|
663
|
-
else (-1, -1)
|
678
|
+
is_swa = (
|
679
|
+
layer.sliding_window_size is not None and layer.sliding_window_size > -1
|
664
680
|
)
|
681
|
+
window_size = (layer.sliding_window_size, 0) if is_swa else (-1, -1)
|
665
682
|
k_descale, v_descale = None, None
|
666
683
|
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
|
667
684
|
# has corresponding quantization method so that layer.k_scale is not None,
|
@@ -684,8 +701,13 @@ class FlashAttentionBackend(AttentionBackend):
|
|
684
701
|
)
|
685
702
|
|
686
703
|
# We do cascade attention for Target Verify with topk > 1
|
704
|
+
# We don't use cascade attention for Sliding Window Attention:
|
705
|
+
# - Different window sizes should be passed in for each q in the first stage of cascade attention, but FA3 interface doesn't support pass in a list of window sizes.
|
706
|
+
# - The overhead of duplicated computation of the common prefix part is small for sliding window layers (seq_len <= window_size), so we can just expand it.
|
687
707
|
use_cascade_attn = (
|
688
|
-
forward_batch.forward_mode.is_target_verify()
|
708
|
+
forward_batch.forward_mode.is_target_verify()
|
709
|
+
and self.topk > 1
|
710
|
+
and not is_swa
|
689
711
|
)
|
690
712
|
|
691
713
|
# For fa3 interface version compatibility, we put new fields into conditional keyword args
|
@@ -700,13 +722,18 @@ class FlashAttentionBackend(AttentionBackend):
|
|
700
722
|
cu_seqlens_q = local_metadata.local_query_start_loc
|
701
723
|
cache_seqlens = local_metadata.local_seqused_k
|
702
724
|
max_seqlen_q = local_metadata.local_max_query_len
|
703
|
-
|
725
|
+
elif is_swa and metadata.swa_spec_metadata is not None:
|
726
|
+
swa_spec_metadata = metadata.swa_spec_metadata
|
727
|
+
page_table = swa_spec_metadata.page_table
|
728
|
+
cu_seqlens_q = swa_spec_metadata.cu_seqlens_q
|
729
|
+
cache_seqlens = swa_spec_metadata.cache_seqlens_int32
|
730
|
+
max_seqlen_q = swa_spec_metadata.max_seq_len_q
|
731
|
+
cu_seqlens_k = swa_spec_metadata.cu_seqlens_k
|
704
732
|
else:
|
705
733
|
page_table = metadata.page_table
|
706
734
|
cu_seqlens_q = metadata.cu_seqlens_q
|
707
735
|
cache_seqlens = metadata.cache_seqlens_int32
|
708
736
|
max_seqlen_q = metadata.max_seq_len_q
|
709
|
-
max_seqlen_k = metadata.max_seq_len_k
|
710
737
|
cu_seqlens_k = metadata.cu_seqlens_k
|
711
738
|
|
712
739
|
# Use Flash Attention for prefill
|
@@ -1377,6 +1404,32 @@ class FlashAttentionBackend(AttentionBackend):
|
|
1377
1404
|
),
|
1378
1405
|
}
|
1379
1406
|
|
1407
|
+
if self.has_swa:
|
1408
|
+
self.target_verify_metadata_topk_swa = {
|
1409
|
+
"cache_seqlens": torch.zeros(
|
1410
|
+
max_bs * self.speculative_num_draft_tokens,
|
1411
|
+
dtype=torch.int32,
|
1412
|
+
device=self.device,
|
1413
|
+
),
|
1414
|
+
"cu_seqlens_k": torch.zeros(
|
1415
|
+
max_bs * self.speculative_num_draft_tokens + 1,
|
1416
|
+
dtype=torch.int32,
|
1417
|
+
device=self.device,
|
1418
|
+
),
|
1419
|
+
"cu_seqlens_q": torch.arange(
|
1420
|
+
0,
|
1421
|
+
max_bs * self.speculative_num_draft_tokens + 1,
|
1422
|
+
dtype=torch.int32,
|
1423
|
+
device=self.device,
|
1424
|
+
),
|
1425
|
+
"page_table": torch.zeros(
|
1426
|
+
max_bs * self.speculative_num_draft_tokens,
|
1427
|
+
self.max_context_len,
|
1428
|
+
dtype=torch.int32,
|
1429
|
+
device=self.device,
|
1430
|
+
),
|
1431
|
+
}
|
1432
|
+
|
1380
1433
|
self.encoder_metadata = {
|
1381
1434
|
"encoder_page_table": torch.zeros(
|
1382
1435
|
max_bs,
|
@@ -1564,6 +1617,28 @@ class FlashAttentionBackend(AttentionBackend):
|
|
1564
1617
|
|
1565
1618
|
self.target_verify_metadata_topk_normal[bs] = metadata
|
1566
1619
|
self.target_verify_metadata_topk_expand[bs] = metadata_expand
|
1620
|
+
|
1621
|
+
if self.has_swa:
|
1622
|
+
metadata_swa = FlashAttentionMetadata()
|
1623
|
+
metadata_swa.cache_seqlens_int32 = (
|
1624
|
+
self.target_verify_metadata_topk_swa["cache_seqlens"][
|
1625
|
+
: bs * self.speculative_num_draft_tokens
|
1626
|
+
]
|
1627
|
+
)
|
1628
|
+
metadata_swa.max_seq_len_q = 1
|
1629
|
+
metadata_swa.cu_seqlens_q = self.target_verify_metadata_topk_swa[
|
1630
|
+
"cu_seqlens_q"
|
1631
|
+
][: bs * self.speculative_num_draft_tokens + 1]
|
1632
|
+
metadata_swa.cu_seqlens_k = self.target_verify_metadata_topk_swa[
|
1633
|
+
"cu_seqlens_k"
|
1634
|
+
][: bs * self.speculative_num_draft_tokens + 1]
|
1635
|
+
|
1636
|
+
metadata_swa.page_table = self.target_verify_metadata_topk_swa[
|
1637
|
+
"page_table"
|
1638
|
+
][: bs * self.speculative_num_draft_tokens]
|
1639
|
+
self.target_verify_metadata_topk_swa[bs] = metadata_swa
|
1640
|
+
metadata.swa_spec_metadata = metadata_swa
|
1641
|
+
|
1567
1642
|
elif forward_mode.is_draft_extend():
|
1568
1643
|
metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
|
1569
1644
|
:bs
|
@@ -1804,6 +1879,12 @@ class FlashAttentionBackend(AttentionBackend):
|
|
1804
1879
|
)
|
1805
1880
|
)
|
1806
1881
|
|
1882
|
+
if self.has_swa:
|
1883
|
+
metadata_swa = self.target_verify_metadata_topk_swa[bs]
|
1884
|
+
self._init_sliding_window_attn_spec_metadata(
|
1885
|
+
metadata, metadata_expand, metadata_swa
|
1886
|
+
)
|
1887
|
+
|
1807
1888
|
elif forward_mode.is_draft_extend():
|
1808
1889
|
metadata = self.draft_extend_metadata[bs]
|
1809
1890
|
metadata.cache_seqlens_int32.copy_(seq_lens)
|
@@ -2039,6 +2120,159 @@ class FlashAttentionBackend(AttentionBackend):
|
|
2039
2120
|
lam.local_max_query_len = int(seqlens_q_local_np.max())
|
2040
2121
|
lam.local_max_seq_len = int(seqlens_k_local_np.max())
|
2041
2122
|
|
2123
|
+
def _init_sliding_window_attn_spec_metadata(
|
2124
|
+
self,
|
2125
|
+
metadata: FlashAttentionMetadata,
|
2126
|
+
metadata_expand: FlashAttentionMetadata,
|
2127
|
+
metadata_swa: Optional[FlashAttentionMetadata] = None,
|
2128
|
+
):
|
2129
|
+
# TODO: support page_size > 1 for swa spec
|
2130
|
+
assert (
|
2131
|
+
self.page_size == 1
|
2132
|
+
), "FlashAttention backend doesn't support topk > 1 speculative decoding with page size > 1 sliding window attention"
|
2133
|
+
|
2134
|
+
cache_seqlens_int32 = (
|
2135
|
+
metadata.cache_seqlens_int32.repeat_interleave(
|
2136
|
+
self.speculative_num_draft_tokens
|
2137
|
+
)
|
2138
|
+
+ metadata_expand.cache_seqlens_int32
|
2139
|
+
)
|
2140
|
+
cu_seqlens_k = torch.nn.functional.pad(
|
2141
|
+
torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32), (1, 0)
|
2142
|
+
)
|
2143
|
+
bs = cache_seqlens_int32.shape[0]
|
2144
|
+
page_table = (
|
2145
|
+
metadata.page_table.new_zeros(
|
2146
|
+
(bs, metadata.max_seq_len_k + metadata_expand.page_table.shape[1])
|
2147
|
+
)
|
2148
|
+
if metadata_swa is None
|
2149
|
+
else metadata_swa.page_table
|
2150
|
+
)
|
2151
|
+
|
2152
|
+
prepare_swa_spec_page_table_triton(
|
2153
|
+
page_table,
|
2154
|
+
metadata.page_table,
|
2155
|
+
metadata_expand.page_table,
|
2156
|
+
metadata.cache_seqlens_int32,
|
2157
|
+
metadata_expand.cache_seqlens_int32,
|
2158
|
+
self.speculative_num_draft_tokens,
|
2159
|
+
)
|
2160
|
+
|
2161
|
+
if metadata_swa is None:
|
2162
|
+
metadata_swa = FlashAttentionMetadata()
|
2163
|
+
metadata_swa.max_seq_len_q = 1
|
2164
|
+
metadata_swa.cu_seqlens_q = metadata_expand.cu_seqlens_q
|
2165
|
+
metadata_swa.cache_seqlens_int32 = cache_seqlens_int32
|
2166
|
+
metadata_swa.cu_seqlens_k = cu_seqlens_k
|
2167
|
+
metadata_swa.page_table = page_table
|
2168
|
+
else:
|
2169
|
+
metadata_swa.cache_seqlens_int32.copy_(cache_seqlens_int32)
|
2170
|
+
metadata_swa.cu_seqlens_k.copy_(cu_seqlens_k)
|
2171
|
+
|
2172
|
+
metadata.swa_spec_metadata = metadata_swa
|
2173
|
+
|
2174
|
+
|
2175
|
+
@triton.jit
|
2176
|
+
def _prepare_swa_spec_page_table_kernel(
|
2177
|
+
dst_ptr,
|
2178
|
+
src_a_ptr,
|
2179
|
+
src_b_ptr,
|
2180
|
+
seq_len_a_ptr,
|
2181
|
+
seq_len_b_ptr,
|
2182
|
+
dst_stride_m,
|
2183
|
+
dst_stride_n,
|
2184
|
+
a_stride_m,
|
2185
|
+
a_stride_n,
|
2186
|
+
b_stride_m,
|
2187
|
+
b_stride_n,
|
2188
|
+
LEN_A: tl.constexpr,
|
2189
|
+
LEN_B: tl.constexpr,
|
2190
|
+
REPEAT_STEP: tl.constexpr,
|
2191
|
+
BLOCK_N: tl.constexpr,
|
2192
|
+
):
|
2193
|
+
pid_m = tl.program_id(0)
|
2194
|
+
pid_n = tl.program_id(1)
|
2195
|
+
|
2196
|
+
idx_a = pid_m // REPEAT_STEP
|
2197
|
+
idx_b = pid_m
|
2198
|
+
seq_len_a = tl.load(seq_len_a_ptr + idx_a)
|
2199
|
+
seq_len_b = tl.load(seq_len_b_ptr + idx_b)
|
2200
|
+
|
2201
|
+
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
2202
|
+
total_len = seq_len_a + seq_len_b
|
2203
|
+
|
2204
|
+
if pid_n * BLOCK_N >= total_len:
|
2205
|
+
return
|
2206
|
+
|
2207
|
+
mask = offs_n < total_len
|
2208
|
+
dst = dst_ptr + pid_m * dst_stride_m + offs_n * dst_stride_n
|
2209
|
+
|
2210
|
+
if (pid_n + 1) * BLOCK_N < seq_len_a:
|
2211
|
+
a_ptr = src_a_ptr + idx_a * a_stride_m + offs_n * a_stride_n
|
2212
|
+
a_mask = mask & (offs_n < LEN_A)
|
2213
|
+
val = tl.load(a_ptr, mask=a_mask, other=0)
|
2214
|
+
tl.store(dst, val, mask=mask)
|
2215
|
+
elif pid_n * BLOCK_N >= seq_len_a:
|
2216
|
+
offs_b = offs_n - seq_len_a
|
2217
|
+
b_ptr = src_b_ptr + idx_b * b_stride_m + offs_b * b_stride_n
|
2218
|
+
b_mask = mask & (offs_b < LEN_B)
|
2219
|
+
val = tl.load(b_ptr, mask=b_mask, other=0)
|
2220
|
+
tl.store(dst, val, mask=mask)
|
2221
|
+
else:
|
2222
|
+
# mixed part
|
2223
|
+
a_offs = offs_n
|
2224
|
+
a_mask = (a_offs < seq_len_a) & (a_offs < LEN_A)
|
2225
|
+
a_ptr = src_a_ptr + idx_a * a_stride_m + a_offs * a_stride_n
|
2226
|
+
a_val = tl.load(a_ptr, mask=a_mask, other=0)
|
2227
|
+
|
2228
|
+
b_offs = offs_n - seq_len_a
|
2229
|
+
b_mask = (b_offs >= 0) & (b_offs < seq_len_b) & (b_offs < LEN_B)
|
2230
|
+
b_ptr = src_b_ptr + idx_b * b_stride_m + b_offs * b_stride_n
|
2231
|
+
b_val = tl.load(b_ptr, mask=b_mask, other=0)
|
2232
|
+
|
2233
|
+
result = tl.where(offs_n < seq_len_a, a_val, b_val)
|
2234
|
+
tl.store(dst, result, mask=mask)
|
2235
|
+
|
2236
|
+
|
2237
|
+
def prepare_swa_spec_page_table_triton(
|
2238
|
+
page_table_dst: torch.Tensor,
|
2239
|
+
page_table_a: torch.Tensor,
|
2240
|
+
page_table_b: torch.Tensor, # expand page table
|
2241
|
+
seq_len_a: torch.Tensor,
|
2242
|
+
seq_len_b: torch.Tensor, # expand seq lens
|
2243
|
+
speculative_num_draft_tokens: int,
|
2244
|
+
):
|
2245
|
+
# concat page_table and expand page_table by kv seq length
|
2246
|
+
bs = seq_len_a.numel()
|
2247
|
+
bs_expand = seq_len_b.numel()
|
2248
|
+
assert bs_expand == bs * speculative_num_draft_tokens
|
2249
|
+
|
2250
|
+
LEN_A = page_table_a.shape[1]
|
2251
|
+
LEN_B = page_table_b.shape[1]
|
2252
|
+
LEN_OUT = LEN_A + LEN_B
|
2253
|
+
REPEAT_STEP = speculative_num_draft_tokens
|
2254
|
+
BLOCK_N = 256
|
2255
|
+
|
2256
|
+
grid = (bs_expand, triton.cdiv(LEN_OUT, BLOCK_N))
|
2257
|
+
_prepare_swa_spec_page_table_kernel[grid](
|
2258
|
+
page_table_dst,
|
2259
|
+
page_table_a,
|
2260
|
+
page_table_b,
|
2261
|
+
seq_len_a,
|
2262
|
+
seq_len_b,
|
2263
|
+
page_table_dst.stride(0),
|
2264
|
+
page_table_dst.stride(1),
|
2265
|
+
page_table_a.stride(0),
|
2266
|
+
page_table_a.stride(1),
|
2267
|
+
page_table_b.stride(0),
|
2268
|
+
page_table_b.stride(1),
|
2269
|
+
LEN_A=LEN_A,
|
2270
|
+
LEN_B=LEN_B,
|
2271
|
+
REPEAT_STEP=REPEAT_STEP,
|
2272
|
+
BLOCK_N=BLOCK_N,
|
2273
|
+
num_warps=4,
|
2274
|
+
)
|
2275
|
+
|
2042
2276
|
|
2043
2277
|
class FlashAttentionMultiStepBackend:
|
2044
2278
|
|
@@ -26,11 +26,14 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
|
26
26
|
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
27
27
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
28
28
|
from sglang.srt.layers.radix_attention import AttentionType
|
29
|
-
from sglang.srt.layers.utils import is_sm100_supported
|
30
29
|
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
31
30
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
32
31
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
33
|
-
from sglang.srt.utils import
|
32
|
+
from sglang.srt.utils import (
|
33
|
+
is_flashinfer_available,
|
34
|
+
is_sm100_supported,
|
35
|
+
next_power_of_2,
|
36
|
+
)
|
34
37
|
|
35
38
|
if TYPE_CHECKING:
|
36
39
|
from sglang.srt.layers.radix_attention import RadixAttention
|