sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +164 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +62 -23
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +26 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +4 -5
- sglang/srt/managers/data_parallel_controller.py +31 -9
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -31
- sglang/srt/managers/scheduler.py +325 -38
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +27 -8
- sglang/srt/mem_cache/memory_pool.py +258 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +85 -28
- sglang/srt/model_executor/forward_batch_info.py +81 -15
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +326 -192
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +145 -47
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +104 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
```diff
@@ -1,14 +1,17 @@
 from __future__ import annotations
 
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union
 
 import torch
 import triton
+import triton.language as tl
 
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import get_bool_env_var, get_device_core_count
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -16,6 +19,71 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 
 
+@triton.jit
+def get_num_kv_splits_triton(
+    num_kv_splits_ptr,
+    seq_lens_ptr,
+    num_seq,
+    num_group,
+    num_head,
+    num_kv_head,
+    max_kv_splits,
+    device_core_count,
+    MAX_NUM_SEQ: tl.constexpr,
+):
+    # TODO: this method is tunable, we need more online serving data to tune it
+    offs_seq = tl.arange(0, MAX_NUM_SEQ)
+    mask_seq = offs_seq < num_seq
+
+    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
+    max_seq_len = tl.max(seq_lens)
+    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
+    min_seq_len = tl.min(seq_lens)
+    if max_seq_len * 8 < min_seq_len * 10:
+        min_seq_len = max_seq_len
+    max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
+    kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
+
+    # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
+    ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
+    ext_device_core_count = tl.cast(
+        device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
+    )
+    block_h, num_kv_group = 16, num_head // num_kv_head
+    if num_kv_group == 1:
+        token_grid = num_seq * num_group * num_head
+    else:
+        # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
+        block_h = tl.minimum(block_h, num_kv_group)
+        token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
+    max_kv_splits_2 = tl.minimum(
+        tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
+    )
+    kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
+
+    num_kv_splits = tl.maximum(
+        tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
+    )
+
+    offs_token = offs_seq * num_group
+    mask_token = offs_token < num_seq * num_group
+    for i in range(0, num_group):
+        tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
+
+
+@dataclass
+class ForwardMetadata:
+    attn_logits: torch.Tensor
+    attn_lse: torch.Tensor
+    max_extend_len: int
+    num_kv_splits: torch.Tensor
+    kv_indptr: torch.Tensor
+    kv_indices: torch.Tensor
+    qo_indptr: torch.Tensor
+    custom_mask: torch.Tensor
+    mask_indptr: torch.Tensor
+
+
 class TritonAttnBackend(AttentionBackend):
     def __init__(
         self,
```
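The hunks above and below appear to come from `sglang/srt/layers/attention/triton_backend.py` (the `+171 -38` entry in the file list). The new `get_num_kv_splits_triton` kernel derives a per-sequence number of KV splits for decode attention from the batch's sequence lengths, the head layout, and the GPU's SM count, rather than always using a static split count. The sketch below is a host-side re-expression of that heuristic in plain PyTorch, for illustration only: the function name is made up here, and it assumes all sequence lengths and `device_core_count` are positive (the real code runs the Triton kernel above and handles the zero-SM case separately).

```python
import math

import torch


def _cdiv(a: int, b: int) -> int:
    return -(-a // b)  # ceiling division for positive ints


def num_kv_splits_reference(
    seq_lens: torch.Tensor,   # int tensor of per-sequence KV lengths, all > 0
    num_head: int,
    num_kv_head: int,
    max_kv_splits: int,
    device_core_count: int,   # number of SMs, assumed > 0
    num_group: int = 1,
) -> torch.Tensor:
    num_seq = seq_lens.numel()
    max_seq_len = int(seq_lens.max())
    min_seq_len = int(seq_lens.min())
    # Treat the batch as uniform when the longest and shortest sequences are
    # within ~25% of each other (max * 8 < min * 10).
    if max_seq_len * 8 < min_seq_len * 10:
        min_seq_len = max_seq_len

    # Scheme 1: bound the split count by how skewed the batch is.
    max_kv_splits_1 = min(_cdiv(max_seq_len, min_seq_len), max_kv_splits)
    kv_chunk_size_1 = _cdiv(max_seq_len, max_kv_splits_1)

    # Scheme 2: let the split count grow with sequence length until the decode
    # kernel's launch grid saturates the available SMs.
    ext_device_core_count = int(
        device_core_count * max(math.log2(max_seq_len / 64.0), 1.0)
    )
    block_h, num_kv_group = 16, num_head // num_kv_head
    if num_kv_group == 1:
        token_grid = num_seq * num_group * num_head
    else:
        # mirrors the grouped decode grid in triton_ops/decode_attention.py
        token_grid = num_seq * num_group * _cdiv(num_head, min(block_h, num_kv_group))
    max_kv_splits_2 = min(_cdiv(ext_device_core_count, token_grid), max_kv_splits)
    kv_chunk_size_2 = _cdiv(max_seq_len, max_kv_splits_2)

    # Per-sequence result: whichever scheme asks for more splits.
    splits_1 = (seq_lens + kv_chunk_size_1 - 1) // kv_chunk_size_1
    splits_2 = (seq_lens + kv_chunk_size_2 - 1) // kv_chunk_size_2
    return torch.maximum(splits_1, splits_2)
```

For example, with `seq_lens = torch.tensor([512, 8192])`, 32 query heads, 8 KV heads, `max_kv_splits=16`, and 132 SMs, this sketch yields 1 split for the 512-token request and 16 for the 8192-token one, which is the per-request scaling the kernel is after.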
```diff
@@ -63,15 +131,55 @@ class TritonAttnBackend(AttentionBackend):
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
+        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
+            get_attention_tp_size()
+        )
 
-        self.
+        self.static_kv_splits = get_bool_env_var(
+            "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
+        )
+        self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
         self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
 
-        self.forward_metadata = None
+        self.forward_metadata: ForwardMetadata = None
 
         self.max_context_len = model_runner.model_config.context_len
 
         self.device = model_runner.device
+        self.device_core_count = get_device_core_count(model_runner.gpu_id)
+
+    def get_num_kv_splits(
+        self,
+        num_kv_splits: torch.Tensor,
+        seq_lens: torch.Tensor,
+    ):
+        num_token, num_seq = num_kv_splits.shape[0], seq_lens.shape[0]
+        num_group = num_token // num_seq
+
+        assert (
+            num_group * num_seq == num_token
+        ), f"num_seq({num_seq}), num_token({num_token}), something goes wrong!"
+
+        if self.static_kv_splits or self.device_core_count <= 0:
+            num_kv_splits.fill_(self.max_kv_splits)
+            return
+
+        if num_seq < 256:
+            SCHEDULE_SEQ = 256
+        else:
+            SCHEDULE_SEQ = triton.next_power_of_2(num_seq)
+
+        get_num_kv_splits_triton[(1,)](
+            num_kv_splits,
+            seq_lens,
+            num_seq,
+            num_group,
+            self.num_head,
+            self.num_kv_head,
+            self.max_kv_splits,
+            self.device_core_count,
+            MAX_NUM_SEQ=SCHEDULE_SEQ,
+        )
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
```
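The constructor hunk above wires two new inputs into this decision: the `SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS` environment flag and the SM count returned by `get_device_core_count(model_runner.gpu_id)`. When the flag is truthy or no SM count is available, `get_num_kv_splits` skips the kernel and fills the whole buffer with the fixed `triton_attention_num_kv_splits` server argument; otherwise it launches the kernel as a single program with `MAX_NUM_SEQ` padded to a power of two (floor of 256). A hypothetical opt-out sketch, assuming `get_bool_env_var` treats the string `"true"` as truthy (its implementation is not part of this diff):

```python
# Hypothetical opt-out sketch: keep a fixed per-request split count instead of
# the new dynamic heuristic. Set the variable before the process constructs
# TritonAttnBackend; "true" is assumed to be accepted as truthy by
# get_bool_env_var(), which is not shown in this diff.
import os

os.environ["SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS"] = "true"

# With the flag set (or when get_device_core_count() reports 0 SMs),
# get_num_kv_splits() reduces to:
#     num_kv_splits.fill_(self.max_kv_splits)
# where max_kv_splits is server_args.triton_attention_num_kv_splits.
```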
```diff
@@ -84,7 +192,7 @@ class TritonAttnBackend(AttentionBackend):
             if spec_info is None:
                 kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                 kv_indptr = kv_indptr[: bs + 1]
-                kv_indices = torch.
+                kv_indices = torch.empty(
                     forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
                 )
                 create_flashinfer_kv_indices_triton[(bs,)](
@@ -100,16 +208,19 @@ class TritonAttnBackend(AttentionBackend):
                 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                 bs = kv_indptr.shape[0] - 1
 
-            attn_logits = torch.
-                (
-
-
-
-
-                ),
+            attn_logits = torch.empty(
+                (bs, self.num_head, self.max_kv_splits, self.v_head_dim),
+                dtype=torch.float32,
+                device=self.device,
+            )
+            attn_lse = torch.empty(
+                (bs, self.num_head, self.max_kv_splits),
                 dtype=torch.float32,
                 device=self.device,
             )
+            num_kv_splits = torch.empty((bs,), dtype=torch.int32, device=self.device)
+
+            self.get_num_kv_splits(num_kv_splits, forward_batch.seq_lens)
 
             qo_indptr = None
             custom_mask = None
@@ -127,7 +238,7 @@ class TritonAttnBackend(AttentionBackend):
             # Different with flashinfer kv_indptr and kv_indices construction
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.
+            kv_indices = torch.empty(
                 kv_indptr[-1], dtype=torch.int32, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
@@ -148,7 +259,9 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
             mask_indptr = mask_indptr[: bs + 1]
             max_extend_len = self.num_draft_tokens
+            num_kv_splits = None
             attn_logits = None
+            attn_lse = None
         elif forward_batch.forward_mode.is_draft_extend():
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
@@ -159,14 +272,19 @@ class TritonAttnBackend(AttentionBackend):
                 )
             )
             mask_indptr = None
+            # TODO(FIXME): This will trigger an invalid Eagle tree when using
+            # `max(spec_info.accept_length_cpu)`.
+            # It might have been forgotten to update somewhere.
             max_extend_len = torch.max(spec_info.accept_length).item()
+            num_kv_splits = None
             attn_logits = None
+            attn_lse = None
         else:
             kv_indptr[1 : bs + 1] = torch.cumsum(
                 forward_batch.extend_prefix_lens, dim=0
             )
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.
+            kv_indices = torch.empty(
                 forward_batch.extend_prefix_lens.sum().item(),
                 dtype=torch.int32,
                 device=self.device,
@@ -187,11 +305,15 @@ class TritonAttnBackend(AttentionBackend):
             custom_mask = None
             mask_indptr = None
             attn_logits = None
+            attn_lse = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
+            num_kv_splits = None
 
-        self.forward_metadata = (
+        self.forward_metadata = ForwardMetadata(
             attn_logits,
+            attn_lse,
             max_extend_len,
+            num_kv_splits,
             kv_indptr,
             kv_indices,
             qo_indptr,
@@ -203,10 +325,18 @@ class TritonAttnBackend(AttentionBackend):
         self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
     ):
         self.cuda_graph_attn_logits = torch.zeros(
-            (max_bs, self.num_head, self.
+            (max_bs, self.num_head, self.max_kv_splits, self.v_head_dim),
             dtype=torch.float32,
             device=self.device,
         )
+        self.cuda_graph_attn_lse = torch.zeros(
+            (max_bs, self.num_head, self.max_kv_splits),
+            dtype=torch.float32,
+            device=self.device,
+        )
+        self.cuda_graph_num_kv_splits = torch.full(
+            (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
+        )
         if kv_indices_buf is None:
             self.cuda_graph_kv_indices = torch.zeros(
                 (max_bs * self.max_context_len),
@@ -254,7 +384,9 @@ class TritonAttnBackend(AttentionBackend):
                 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
 
             attn_logits = self.cuda_graph_attn_logits
+            attn_lse = self.cuda_graph_attn_lse
             max_extend_len = None
+            num_kv_splits = self.cuda_graph_num_kv_splits
             qo_indptr = None
             custom_mask = None
             mask_indptr = None
@@ -285,15 +417,19 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr = self.mask_indptr[: bs + 1]
             mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
             max_extend_len = self.num_draft_tokens
+            num_kv_splits = None
             attn_logits = None
+            attn_lse = None
         else:
             raise ValueError(
                 f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
             )
 
-        self.forward_metadata = (
+        self.forward_metadata = ForwardMetadata(
             attn_logits,
+            attn_lse,
             max_extend_len,
+            num_kv_splits,
             kv_indptr,
             kv_indices,
             qo_indptr,
@@ -317,6 +453,7 @@ class TritonAttnBackend(AttentionBackend):
             # Update kv_indptr, kv_indices
             kv_indptr = self.kv_indptr
             kv_indices = self.cuda_graph_kv_indices
+            num_kv_splits = self.cuda_graph_num_kv_splits
             if spec_info is None:
                 kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
                 kv_indptr = kv_indptr[: bs + 1]
@@ -329,9 +466,12 @@ class TritonAttnBackend(AttentionBackend):
                     kv_indices,
                     self.req_to_token.stride(0),
                 )
+                num_token = bs
            else:
                 kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
                 kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
+                num_token = spec_info.kv_indptr.shape[0] - 1
+            self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])
         elif forward_mode.is_target_verify():
             # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
             bs = len(req_pool_indices)
@@ -388,16 +528,6 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )
 
-        (
-            _,
-            max_extend_len,
-            kv_indptr,
-            kv_indices,
-            qo_indptr,
-            custom_mask,
-            mask_indptr,
-        ) = self.forward_metadata
-
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -405,12 +535,12 @@ class TritonAttnBackend(AttentionBackend):
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
-            qo_indptr,
-            kv_indptr,
-            kv_indices,
-            custom_mask,
-            mask_indptr,
-            max_extend_len,
+            self.forward_metadata.qo_indptr,
+            self.forward_metadata.kv_indptr,
+            self.forward_metadata.kv_indices,
+            self.forward_metadata.custom_mask,
+            self.forward_metadata.mask_indptr,
+            self.forward_metadata.max_extend_len,
             layer.scaling,
             layer.logit_cap,
         )
@@ -435,8 +565,6 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
-        attn_logits, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata
-
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
                 layer, forward_batch.out_cache_loc, k, v
@@ -447,10 +575,12 @@ class TritonAttnBackend(AttentionBackend):
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
-            kv_indptr,
-            kv_indices,
-            attn_logits,
-            self.
+            self.forward_metadata.kv_indptr,
+            self.forward_metadata.kv_indices,
+            self.forward_metadata.attn_logits,
+            self.forward_metadata.attn_lse,
+            self.forward_metadata.num_kv_splits,
+            self.max_kv_splits,
             layer.scaling,
             layer.logit_cap,
         )
```
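The `forward_extend` and `forward_decode` hunks above replace positional unpacking of `self.forward_metadata` (previously a plain tuple) with attribute access on the new `ForwardMetadata` dataclass. A minimal sketch, not sglang code, of why that refactor pays off once fields like `attn_lse` and `num_kv_splits` are inserted into the middle of the record:

```python
# Minimal sketch (not sglang code): inserting new fields into a positional tuple
# silently shifts every `a, b, c = meta` unpacking site, while a dataclass keeps
# existing call sites valid when fields are added.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ForwardMetadata:
    attn_logits: Optional[torch.Tensor]
    attn_lse: Optional[torch.Tensor]        # inserted in 0.4.4.post2
    max_extend_len: Optional[int]
    num_kv_splits: Optional[torch.Tensor]   # inserted in 0.4.4.post2
    # kv_indptr, kv_indices, qo_indptr, custom_mask, mask_indptr elided here;
    # see the full dataclass in the first hunk of this file.


meta = ForwardMetadata(
    attn_logits=None, attn_lse=None, max_extend_len=8, num_kv_splits=None
)
print(meta.max_extend_len)  # named access is unaffected by the inserted fields
```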
```diff
@@ -493,6 +623,9 @@ class TritonMultiStepDraftBackend:
                 )
             )
         self.max_context_len = self.attn_backends[0].max_context_len
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
         self.device = model_runner.device
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
@@ -531,7 +664,7 @@ class TritonMultiStepDraftBackend:
             call_fn(i, forward_batch)
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        kv_indices = torch.
+        kv_indices = torch.empty(
            (
                 self.speculative_num_steps,
                 forward_batch.batch_size * self.topk * self.max_context_len,
```