sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_one_batch.py +8 -6
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +19 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +211 -72
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +30 -29
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +144 -55
- sglang/srt/disaggregation/utils.py +155 -123
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +37 -29
- sglang/srt/entrypoints/http_server.py +153 -72
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +40 -3
- sglang/srt/layers/attention/aiter_backend.py +20 -4
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
- sglang/srt/layers/attention/flashattention_backend.py +71 -72
- sglang/srt/layers/attention/flashinfer_backend.py +10 -8
- sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -12
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +138 -130
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +28 -10
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +2 -14
- sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
- sglang/srt/layers/moe/ep_moe/layer.py +249 -33
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +107 -12
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +249 -105
- sglang/srt/lora/mem_pool.py +53 -50
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -14
- sglang/srt/managers/io_struct.py +31 -10
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +79 -37
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +220 -79
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +40 -10
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -15
- sglang/srt/mem_cache/hiradix_cache.py +38 -25
- sglang/srt/mem_cache/memory_pool.py +213 -505
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +56 -28
- sglang/srt/model_executor/cuda_graph_runner.py +198 -100
- sglang/srt/model_executor/forward_batch_info.py +32 -10
- sglang/srt/model_executor/model_runner.py +28 -12
- sglang/srt/model_loader/loader.py +16 -2
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +213 -173
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +351 -238
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
- sglang/srt/speculative/eagle_utils.py +468 -116
- sglang/srt/speculative/eagle_worker.py +258 -84
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +235 -11
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +2 -0
- sglang/test/test_utils.py +4 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -1990
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashmla_backend.py

@@ -2,9 +2,6 @@ from __future__ import annotations
 
 """
 Support attention backend for FlashMLA.
-
-#TODO
-Enable speculative sampling in FlashMLA
 """
 
 from dataclasses import dataclass
@@ -14,8 +11,6 @@ import torch
 import triton
 from flash_mla import flash_mla_with_kvcache, get_mla_metadata
 
-from sglang.global_config import global_config
-from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
 from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
@@ -24,7 +19,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
     from sglang.srt.speculative.spec_info import SpecInfo
 
 
@@ -154,6 +148,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
     def init_cuda_graph_state(
         self,
         max_bs: int,
+        max_num_tokens: int,
         block_kv_indices: Optional[torch.Tensor] = None,
     ):
         if block_kv_indices is None:
@@ -330,7 +325,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         )
 
     def get_cuda_graph_seq_len_fill_value(self):
-        return
+        return 1
 
     def forward_decode(
         self,
@@ -464,11 +459,9 @@ class FlashMLAMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
-
         if topk > 1:
             raise ValueError(
-
+                "Currently FlashMLA only supports topk=1 for speculative decoding"
             )
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
@@ -510,9 +503,11 @@ class FlashMLAMultiStepDraftBackend:
 
         self.common_template(forward_batch, call_fn)
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         for i in range(self.speculative_num_steps):
-            self.attn_backends[i].init_cuda_graph_state(
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, max_num_tokens, block_kv_indices=None
+            )
 
     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
         def call_fn(i, forward_batch):
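Across these attention backends, 0.4.8 threads a new `max_num_tokens` argument through `init_cuda_graph_state` so CUDA-graph buffers can be sized per token instead of per request. A minimal sketch of the new call shape, assuming a toy backend; the class name, buffer name, and example sizes below are illustrative, not sglang code:

```python
# Sketch of the 0.4.8 init_cuda_graph_state signature seen in these hunks.
from typing import Optional

import torch


class ToyAttnBackend:
    def __init__(self, max_context_len: int = 2048, device: str = "cpu"):
        self.max_context_len = max_context_len
        self.device = device

    # 0.4.7 took only max_bs; 0.4.8 adds max_num_tokens so buffers can be
    # sized per token (useful when a graph covers several tokens per request).
    def init_cuda_graph_state(
        self,
        max_bs: int,
        max_num_tokens: int,
        block_kv_indices: Optional[torch.Tensor] = None,
    ):
        if block_kv_indices is None:
            # Token-level sizing instead of batch-level sizing.
            self.kv_indices = torch.zeros(
                max_num_tokens * self.max_context_len,
                dtype=torch.int32,
                device=self.device,
            )


backend = ToyAttnBackend()
# e.g. 4 draft tokens per request -> 8 * 4 token slots
backend.init_cuda_graph_state(max_bs=8, max_num_tokens=8 * 4)
```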
sglang/srt/layers/attention/tbo_backend.py

@@ -32,11 +32,11 @@ class TboAttnBackend(AttentionBackend):
             if forward_batch_child.batch_size > 0:
                 child.init_forward_metadata(forward_batch=forward_batch_child)
 
-    def init_cuda_graph_state(self, max_bs: int):
-        self.primary.init_cuda_graph_state(max_bs=max_bs)
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        self.primary.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)
         for item in self.children:
             # TODO for children, maybe can provide *smaller* max_bs to optimize
-            item.init_cuda_graph_state(max_bs=max_bs)
+            item.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)
 
     def init_forward_metadata_capture_cuda_graph(
         self,
sglang/srt/layers/attention/triton_backend.py

@@ -12,7 +12,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.utils import get_bool_env_var, get_device_core_count
+from sglang.srt.utils import get_bool_env_var, get_device_core_count, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -20,117 +20,6 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 
 
-@triton.jit
-def get_num_kv_splits_triton(
-    num_kv_splits_ptr,
-    seq_lens_ptr,
-    num_seq,
-    num_group,
-    num_head,
-    num_kv_head,
-    max_kv_splits,
-    device_core_count,
-    MAX_NUM_SEQ: tl.constexpr,
-):
-    # TODO: this method is tunable, we need more online serving data to tune it
-    offs_seq = tl.arange(0, MAX_NUM_SEQ)
-    mask_seq = offs_seq < num_seq
-
-    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
-    max_seq_len = tl.max(seq_lens)
-    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
-    min_seq_len = tl.min(seq_lens)
-    if max_seq_len * 8 < min_seq_len * 10:
-        min_seq_len = max_seq_len
-    max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
-    kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
-
-    # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
-    ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
-    ext_device_core_count = tl.cast(
-        device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
-    )
-    block_h, num_kv_group = 16, num_head // num_kv_head
-    if num_kv_group == 1:
-        token_grid = num_seq * num_group * num_head
-    else:
-        # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
-        block_h = tl.minimum(block_h, num_kv_group)
-        token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
-    max_kv_splits_2 = tl.minimum(
-        tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
-    )
-    kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
-
-    num_kv_splits = tl.maximum(
-        tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
-    )
-
-    offs_token = offs_seq * num_group
-    mask_token = offs_token < num_seq * num_group
-    for i in range(0, num_group):
-        tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
-
-
-def update_sliding_window_buffer(
-    window_kv_indptr,
-    req_to_token,
-    sliding_window_size,
-    seq_lens,
-    req_pool_indices,
-    bs,
-    device,
-):
-    window_kv_lens = torch.minimum(
-        seq_lens,
-        torch.tensor(sliding_window_size + 1),
-    )
-    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
-    window_kv_indptr = window_kv_indptr[: bs + 1]
-    window_kv_indices = torch.empty(
-        window_kv_indptr[-1], dtype=torch.int32, device=device
-    )
-    window_kv_start_idx = seq_lens - window_kv_lens
-    create_flashinfer_kv_indices_triton[(bs,)](
-        req_to_token,
-        req_pool_indices,
-        window_kv_lens,
-        window_kv_indptr,
-        window_kv_start_idx,
-        window_kv_indices,
-        req_to_token.stride(0),
-    )
-    return window_kv_indptr, window_kv_indices, window_kv_lens
-
-
-def update_sliding_window_buffer_cuda_graph(
-    window_kv_indptr,
-    window_kv_indices,
-    req_to_token,
-    sliding_window_size,
-    seq_lens,
-    req_pool_indices,
-    bs,
-):
-    window_kv_lens = torch.minimum(
-        seq_lens,
-        torch.tensor(sliding_window_size + 1),
-    )
-    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
-    window_kv_indptr = window_kv_indptr[: bs + 1]
-    window_kv_start_idx = seq_lens - window_kv_lens
-    create_flashinfer_kv_indices_triton[(bs,)](
-        req_to_token,
-        req_pool_indices,
-        window_kv_lens,
-        window_kv_indptr,
-        window_kv_start_idx,
-        window_kv_indices,
-        req_to_token.stride(0),
-    )
-    return window_kv_indptr, window_kv_lens
-
-
 @dataclass
 class ForwardMetadata:
     attn_logits: torch.Tensor
@@ -165,8 +54,8 @@ class TritonAttnBackend(AttentionBackend):
 
         super().__init__()
 
-        self.decode_attention_fwd = decode_attention_fwd
-        self.extend_attention_fwd = extend_attention_fwd
+        self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
+        self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
 
         self.skip_prefill = skip_prefill
 
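The constructor change above wraps the Triton decode/extend kernels in `torch.compiler.disable`, which keeps `torch.compile`/Dynamo from tracing into the kernel launch and instead treats the call as a graph break. A small self-contained sketch of the same pattern with a stand-in function:

```python
import torch


def triton_like_kernel(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a Triton launch that Dynamo should not trace into.
    return x * 2


# Same wrapping pattern as the hunk above applies to
# decode_attention_fwd / extend_attention_fwd.
safe_kernel = torch.compiler.disable(triton_like_kernel)


@torch.compile
def forward(x: torch.Tensor) -> torch.Tensor:
    # The wrapped call runs eagerly; the rest of the function can be compiled.
    return safe_kernel(x) + 1


print(forward(torch.ones(4)))
```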
@@ -372,6 +261,7 @@ class TritonAttnBackend(AttentionBackend):
             num_kv_splits = None
             attn_logits = None
             attn_lse = None
+
         elif forward_batch.forward_mode.is_draft_extend():
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
@@ -446,24 +336,27 @@ class TritonAttnBackend(AttentionBackend):
         )
 
     def init_cuda_graph_state(
-        self,
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         self.cuda_graph_attn_logits = torch.zeros(
-            (
+            (max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
             dtype=torch.float32,
             device=self.device,
         )
         self.cuda_graph_attn_lse = torch.zeros(
-            (
+            (max_num_tokens, self.num_head, self.max_kv_splits),
            dtype=torch.float32,
             device=self.device,
         )
         self.cuda_graph_num_kv_splits = torch.full(
-            (
+            (max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
         )
         if kv_indices_buf is None:
             self.cuda_graph_kv_indices = torch.zeros(
-                (
+                (max_num_tokens * self.max_context_len),
                 dtype=torch.int32,
                 device=self.device,
             )
@@ -472,7 +365,7 @@ class TritonAttnBackend(AttentionBackend):
 
         if not self.skip_prefill:
             self.cuda_graph_custom_mask = torch.zeros(
-                (
+                (max_num_tokens * self.max_context_len),
                 dtype=torch.uint8,
                 device=self.device,
             )
@@ -480,7 +373,7 @@ class TritonAttnBackend(AttentionBackend):
         if self.sliding_window_size is not None and self.sliding_window_size > 0:
             if kv_indices_buf is None:
                 self.cuda_graph_window_kv_indices = torch.zeros(
-                    (
+                    (max_num_tokens * self.sliding_window_size),
                     dtype=torch.int32,
                     device=self.device,
                 )
@@ -488,7 +381,10 @@ class TritonAttnBackend(AttentionBackend):
                 self.cuda_graph_window_kv_indices = torch.zeros_like(kv_indices_buf)
 
             self.cuda_graph_window_num_kv_splits = torch.full(
-                (
+                (max_num_tokens,),
+                self.max_kv_splits,
+                dtype=torch.int32,
+                device=self.device,
             )
 
     def init_forward_metadata_capture_cuda_graph(
@@ -569,6 +465,7 @@ class TritonAttnBackend(AttentionBackend):
             )
 
             custom_mask = self.cuda_graph_custom_mask
+            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
             seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
             mask_indptr = self.mask_indptr[: bs + 1]
             mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
@@ -877,6 +774,7 @@ class TritonMultiStepDraftBackend:
         self.device = model_runner.device
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+        self.page_size = model_runner.server_args.page_size
 
     def common_template(
         self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
@@ -894,14 +792,13 @@ class TritonMultiStepDraftBackend:
             kv_indices_buffer,
             self.kv_indptr,
             forward_batch.positions,
-            num_seqs,
-            self.topk,
             self.pool_len,
             kv_indices_buffer.shape[1],
             self.kv_indptr.shape[1],
-
-
-
+            next_power_of_2(num_seqs),
+            next_power_of_2(self.speculative_num_steps),
+            next_power_of_2(bs),
+            self.page_size,
         )
 
         for i in range(self.speculative_num_steps):
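The kv-indices hunk above now rounds several launch parameters up with `next_power_of_2` (imported from `sglang.srt.utils` earlier in this diff) and passes the page size explicitly. The exact sglang implementation of the helper is not shown in this diff; the following is a common bit-twiddling version of such a function:

```python
def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n (assumed behavior; the sglang version may differ).
    return 1 if n <= 0 else 1 << (n - 1).bit_length()


assert [next_power_of_2(x) for x in (1, 3, 8, 9)] == [1, 4, 8, 16]
```

Powers of two are convenient here because Triton block/grid parameters generally want power-of-two extents.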
@@ -932,15 +829,15 @@ class TritonMultiStepDraftBackend:
 
         self.common_template(forward_batch, kv_indices, call_fn)
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.cuda_graph_kv_indices = torch.zeros(
-            (self.speculative_num_steps,
+            (self.speculative_num_steps, max_num_tokens * self.max_context_len),
             dtype=torch.int32,
             device=self.device,
         )
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
 
     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
@@ -973,3 +870,114 @@ class TritonMultiStepDraftBackend:
         )
 
         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+
+@triton.jit
+def get_num_kv_splits_triton(
+    num_kv_splits_ptr,
+    seq_lens_ptr,
+    num_seq,
+    num_group,
+    num_head,
+    num_kv_head,
+    max_kv_splits,
+    device_core_count,
+    MAX_NUM_SEQ: tl.constexpr,
+):
+    # TODO: this method is tunable, we need more online serving data to tune it
+    offs_seq = tl.arange(0, MAX_NUM_SEQ)
+    mask_seq = offs_seq < num_seq
+
+    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
+    max_seq_len = tl.max(seq_lens)
+    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
+    min_seq_len = tl.min(seq_lens)
+    if max_seq_len * 8 < min_seq_len * 10:
+        min_seq_len = max_seq_len
+    max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
+    kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
+
+    # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
+    ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
+    ext_device_core_count = tl.cast(
+        device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
+    )
+    block_h, num_kv_group = 16, num_head // num_kv_head
+    if num_kv_group == 1:
+        token_grid = num_seq * num_group * num_head
+    else:
+        # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
+        block_h = tl.minimum(block_h, num_kv_group)
+        token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
+    max_kv_splits_2 = tl.minimum(
+        tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
+    )
+    kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
+
+    num_kv_splits = tl.maximum(
+        tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
+    )
+
+    offs_token = offs_seq * num_group
+    mask_token = offs_token < num_seq * num_group
+    for i in range(0, num_group):
+        tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
+
+
+def update_sliding_window_buffer(
+    window_kv_indptr,
+    req_to_token,
+    sliding_window_size,
+    seq_lens,
+    req_pool_indices,
+    bs,
+    device,
+):
+    window_kv_lens = torch.minimum(
+        seq_lens,
+        torch.tensor(sliding_window_size + 1),
+    )
+    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
+    window_kv_indptr = window_kv_indptr[: bs + 1]
+    window_kv_indices = torch.empty(
+        window_kv_indptr[-1], dtype=torch.int32, device=device
+    )
+    window_kv_start_idx = seq_lens - window_kv_lens
+    create_flashinfer_kv_indices_triton[(bs,)](
+        req_to_token,
+        req_pool_indices,
+        window_kv_lens,
+        window_kv_indptr,
+        window_kv_start_idx,
+        window_kv_indices,
+        req_to_token.stride(0),
+    )
+    return window_kv_indptr, window_kv_indices, window_kv_lens
+
+
+def update_sliding_window_buffer_cuda_graph(
+    window_kv_indptr,
+    window_kv_indices,
+    req_to_token,
+    sliding_window_size,
+    seq_lens,
+    req_pool_indices,
+    bs,
+):
+    window_kv_lens = torch.minimum(
+        seq_lens,
+        torch.tensor(sliding_window_size + 1),
+    )
+    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
+    window_kv_indptr = window_kv_indptr[: bs + 1]
+    window_kv_start_idx = seq_lens - window_kv_lens
+    create_flashinfer_kv_indices_triton[(bs,)](
+        req_to_token,
+        req_pool_indices,
+        window_kv_lens,
+        window_kv_indptr,
+        window_kv_start_idx,
+        window_kv_indices,
+        req_to_token.stride(0),
+    )
+    return window_kv_indptr, window_kv_lens
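The relocated sliding-window helpers above build a CSR-style `indptr` from per-request window KV lengths with `torch.cumsum`. A standalone illustration of that step with made-up lengths:

```python
import torch

sliding_window_size = 4
seq_lens = torch.tensor([2, 7, 5])
# Each request keeps at most sliding_window_size + 1 KV entries.
window_kv_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size + 1))

bs = seq_lens.numel()
window_kv_indptr = torch.zeros(bs + 1, dtype=torch.int64)
window_kv_indptr[1:] = torch.cumsum(window_kv_lens, dim=0)

print(window_kv_lens)    # tensor([2, 5, 5])
print(window_kv_indptr)  # tensor([0, 2, 7, 12]); request i owns slots indptr[i]:indptr[i+1]
```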
sglang/srt/layers/attention/triton_ops/decode_attention.py

@@ -31,11 +31,6 @@ _is_hip = is_hip()
 
 logger = logging.getLogger(__name__)
 
-# TODO: Remove this when triton>=3.2.0. This issue will not affect performance and accuracy.
-logger.warning(
-    "The following error message 'operation scheduled before its operands' can be ignored."
-)
-
 
 _MIN_BLOCK_KV = 32
 
@@ -713,7 +708,7 @@ def decode_attention_fwd(
             num_kv_splits,
             max_kv_splits,
             sm_scale,
-            logit_cap,
+            logit_cap=logit_cap,
         )
     else:
         # GQA/MQA/MLA
@@ -729,5 +724,5 @@
             num_kv_splits,
             max_kv_splits,
             sm_scale,
-            logit_cap,
+            logit_cap=logit_cap,
        )
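The two `logit_cap` changes above switch a positional argument to a keyword argument. One likely motivation is robustness: if the callee's parameter list grows, a trailing positional value can silently land in the wrong slot. A toy illustration (not the sglang kernel signature):

```python
# Hypothetical callee whose signature gained a page_size parameter.
def decode_kernel(sm_scale, page_size=1, logit_cap=0.0):
    return sm_scale, page_size, logit_cap


# Positional style: 30.0 was meant as logit_cap but now fills page_size.
print(decode_kernel(0.08, 30.0))            # (0.08, 30.0, 0.0)  <- wrong slot
# Keyword style keeps the intent explicit regardless of signature changes.
print(decode_kernel(0.08, logit_cap=30.0))  # (0.08, 1, 30.0)
```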
sglang/srt/layers/attention/vision.py

@@ -1,15 +1,17 @@
 from __future__ import annotations
 
+import dataclasses
+import functools
 import math
-from functools import lru_cache
-from typing import Optional, Tuple
+from functools import lru_cache
+from typing import Any, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, print_info_once
 
 _is_cuda = is_cuda()
 
@@ -29,29 +31,42 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.quantization import QuantizationConfig
 from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix
 
 ROTARY_EMBED_CLASSES = {
     "normal": apply_rotary_pos_emb,
 }
 
 
-
-
+@dataclasses.dataclass
+class SingletonCache:
+    data: Any = None
 
-
-
-        nonlocal has_run
-        if not has_run:
-            func(*args, **kwargs)
-            has_run = True
+    def set_data(self, value: Any) -> None:
+        self.data = value
 
-
+    def get_data(self) -> Optional[Any]:
+        return self.data
 
+    def empty(self) -> bool:
+        return self.get_data() is None
 
-
-
-
+
+# TODO: requires real seqlens from images
+@functools.lru_cache(maxsize=128)
+def _get_cu_seqlens_for_shape(batch_size: int, seqlen: int, device) -> torch.Tensor:
+    """
+    Generates cumulative sequence lengths (cu_seqlens) for a given batch_size, seqlen, and device.
+    Caches the result based on these parameters.
+    """
+    cu_seqlens = torch.arange(
+        0,
+        (batch_size + 1) * seqlen,
+        step=seqlen,
+        dtype=torch.int32,
+        device=device,
+    )
+    return cu_seqlens
 
 
 class VisionSdpaAttention(nn.Module):
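A usage sketch for the two additions above, assuming the `SingletonCache` and `_get_cu_seqlens_for_shape` definitions just shown are in scope; the batch shape and device are made up. The cache object is presumably created once per multimodal forward and shared, so the cumulative-length tensor is computed a single time:

```python
import torch

cache = SingletonCache()

bsz, seq_len = 2, 3
if cache.empty():
    # First user fills the cache; later users reuse the stored tensor.
    cache.set_data(_get_cu_seqlens_for_shape(bsz, seq_len, device=torch.device("cpu")))

cu_seqlens = cache.get_data()
print(cu_seqlens)  # tensor([0, 3, 6], dtype=torch.int32)
```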
@@ -265,8 +280,9 @@ class VisionFlash3Attention(nn.Module):
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        cu_seqlens: Optional[torch.Tensor],
-
+        cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
+        bsz: int,
+        seq_len: int,
         **kwargs,
     ) -> torch.Tensor:
         r"""
@@ -275,7 +291,16 @@ class VisionFlash3Attention(nn.Module):
         Returns:
             [b * s, h, head_size]
         """
-        cu_seqlens
+        if cu_seqlens is None:
+            cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
+        elif isinstance(cu_seqlens, SingletonCache):
+            if cu_seqlens.empty():
+                cu_seqlens.set_data(
+                    _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
+                )
+            cu_seqlens = cu_seqlens.get_data()
+
+        cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
         seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
         max_seqlen = seq_lens.max().item()
         output = flash_attn_varlen_func(
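The hunk above resolves `cu_seqlens` from three possible inputs (nothing, a shared `SingletonCache`, or a ready tensor). The same decision logic written as a standalone helper for clarity; the function name is hypothetical, and it assumes the `SingletonCache` / `_get_cu_seqlens_for_shape` definitions added earlier in this diff:

```python
from typing import Optional, Union

import torch


def resolve_cu_seqlens(
    cu_seqlens: Optional[Union["SingletonCache", torch.Tensor]],
    bsz: int,
    seq_len: int,
    device: torch.device,
) -> torch.Tensor:
    if cu_seqlens is None:
        # No lengths given: assume bsz equal-length sequences of seq_len tokens.
        return _get_cu_seqlens_for_shape(bsz, seq_len, device=device)
    if isinstance(cu_seqlens, SingletonCache):
        # Shared cache: fill it on first use, then read the stored tensor.
        if cu_seqlens.empty():
            cu_seqlens.set_data(_get_cu_seqlens_for_shape(bsz, seq_len, device=device))
        return cu_seqlens.get_data()
    # Already a tensor: use as-is.
    return cu_seqlens
```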
@@ -346,11 +371,11 @@ class VisionAttention(nn.Module):
         if global_server_args_dict["mm_attention_backend"] is None:
             if qkv_backend is None:
                 qkv_backend = "sdpa"
-
+            print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
         else:
             qkv_backend = global_server_args_dict["mm_attention_backend"]
 
-
+        print_info_once(f"Using {qkv_backend} as multimodal attention backend.")
 
         self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
             head_dim=self.head_size,
@@ -423,15 +448,16 @@ class VisionAttention(nn.Module):
         # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
         qkv, _ = self.qkv_proj(x)
 
-        # [s, b, head
+        # [s, b, head, head_dim_sum]
         new_x_shape = qkv.size()[:-1] + (
             head,
-
+            self.q_size + 2 * self.kv_size,
         )
         qkv = qkv.view(*new_x_shape)
 
         # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
-        q, k, v =
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
         # [s, b, head, head_size] --> [b, s, head, head_size]
         q, k, v = [
             rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
@@ -468,6 +494,7 @@ class VisionAttention(nn.Module):
             k=k,
             v=v,
             bsz=bsz,
+            seq_len=s,
             cu_seqlens=cu_seqlens,
             attention_mask=attention_mask,
         )