sglang 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +23 -1
- sglang/bench_latency.py +46 -25
- sglang/bench_serving.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +14 -1
- sglang/lang/interpreter.py +16 -6
- sglang/lang/ir.py +20 -4
- sglang/srt/configs/model_config.py +11 -9
- sglang/srt/constrained/fsm_cache.py +9 -1
- sglang/srt/constrained/jump_forward.py +15 -2
- sglang/srt/layers/activation.py +4 -4
- sglang/srt/layers/attention/__init__.py +49 -0
- sglang/srt/layers/attention/flashinfer_backend.py +277 -0
- sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
- sglang/srt/layers/attention/triton_backend.py +161 -0
- sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
- sglang/srt/layers/layernorm.py +4 -4
- sglang/srt/layers/logits_processor.py +19 -15
- sglang/srt/layers/pooler.py +3 -3
- sglang/srt/layers/quantization/__init__.py +0 -2
- sglang/srt/layers/radix_attention.py +6 -4
- sglang/srt/layers/sampler.py +6 -4
- sglang/srt/layers/torchao_utils.py +18 -0
- sglang/srt/lora/lora.py +20 -21
- sglang/srt/lora/lora_manager.py +97 -25
- sglang/srt/managers/detokenizer_manager.py +31 -18
- sglang/srt/managers/image_processor.py +187 -0
- sglang/srt/managers/io_struct.py +99 -75
- sglang/srt/managers/schedule_batch.py +184 -63
- sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
- sglang/srt/managers/scheduler.py +1021 -0
- sglang/srt/managers/tokenizer_manager.py +120 -248
- sglang/srt/managers/tp_worker.py +28 -925
- sglang/srt/mem_cache/memory_pool.py +34 -52
- sglang/srt/model_executor/cuda_graph_runner.py +15 -19
- sglang/srt/model_executor/forward_batch_info.py +94 -95
- sglang/srt/model_executor/model_runner.py +76 -75
- sglang/srt/models/baichuan.py +10 -10
- sglang/srt/models/chatglm.py +12 -12
- sglang/srt/models/commandr.py +10 -10
- sglang/srt/models/dbrx.py +12 -12
- sglang/srt/models/deepseek.py +10 -10
- sglang/srt/models/deepseek_v2.py +14 -15
- sglang/srt/models/exaone.py +10 -10
- sglang/srt/models/gemma.py +10 -10
- sglang/srt/models/gemma2.py +11 -11
- sglang/srt/models/gpt_bigcode.py +10 -10
- sglang/srt/models/grok.py +10 -10
- sglang/srt/models/internlm2.py +10 -10
- sglang/srt/models/llama.py +14 -10
- sglang/srt/models/llama_classification.py +5 -5
- sglang/srt/models/llama_embedding.py +4 -4
- sglang/srt/models/llama_reward.py +142 -0
- sglang/srt/models/llava.py +39 -33
- sglang/srt/models/llavavid.py +31 -28
- sglang/srt/models/minicpm.py +10 -10
- sglang/srt/models/minicpm3.py +14 -15
- sglang/srt/models/mixtral.py +10 -10
- sglang/srt/models/mixtral_quant.py +10 -10
- sglang/srt/models/olmoe.py +10 -10
- sglang/srt/models/qwen.py +10 -10
- sglang/srt/models/qwen2.py +11 -11
- sglang/srt/models/qwen2_moe.py +10 -10
- sglang/srt/models/stablelm.py +10 -10
- sglang/srt/models/torch_native_llama.py +506 -0
- sglang/srt/models/xverse.py +10 -10
- sglang/srt/models/xverse_moe.py +10 -10
- sglang/srt/sampling/sampling_batch_info.py +36 -27
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +170 -119
- sglang/srt/server_args.py +54 -27
- sglang/srt/utils.py +101 -128
- sglang/test/runners.py +71 -26
- sglang/test/test_programs.py +38 -5
- sglang/test/test_utils.py +18 -9
- sglang/version.py +1 -1
- {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/METADATA +37 -19
- sglang-0.3.3.dist-info/RECORD +139 -0
- sglang/srt/layers/attention_backend.py +0 -474
- sglang/srt/managers/controller_multi.py +0 -207
- sglang/srt/managers/controller_single.py +0 -164
- sglang-0.3.2.dist-info/RECORD +0 -135
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
- {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
- {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
- {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/__init__.py (new file)
@@ -0,0 +1,49 @@
+from abc import ABC, abstractmethod
+
+from torch import nn
+
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+
+class AttentionBackend(ABC):
+    """The base class of attention backends"""
+
+    @abstractmethod
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Init the metadata for a forward pass."""
+        raise NotImplementedError()
+
+    def init_cuda_graph_state(self, max_bs: int):
+        """Init the global shared states for cuda graph."""
+        raise NotImplementedError()
+
+    def init_forward_metadata_capture_cuda_graph(
+        self, bs: int, req_pool_indices, seq_lens
+    ):
+        """Init the metadata for a forward pass for capturing a cuda graph."""
+        raise NotImplementedError()
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, bs: int, req_pool_indices, seq_lens
+    ):
+        """Init the metadata for a forward pass for replying a cuda graph."""
+        raise NotImplementedError()
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
+        raise NotImplementedError()
+
+    def forward(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+        """Run forward on an attention layer."""
+        if forward_batch.forward_mode.is_decode():
+            return self.forward_decode(q, k, v, layer, forward_batch)
+        else:
+            return self.forward_extend(q, k, v, layer, forward_batch)
+
+    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+        """Run a forward for decode."""
+        raise NotImplementedError()
+
+    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+        """Run a forward for extend."""
+        raise NotImplementedError()
sglang/srt/layers/attention/flashinfer_backend.py (new file)
@@ -0,0 +1,277 @@
+from __future__ import annotations
+
+"""
+Support different attention backends.
+Now there are two backends: FlashInfer and Triton.
+FlashInfer is faster and Triton is easier to customize.
+Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
+"""
+
+from typing import TYPE_CHECKING
+
+import torch
+import torch.nn as nn
+
+from sglang.global_config import global_config
+from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.attention.flashinfer_utils import (
+    WrapperDispatch,
+    update_flashinfer_indices,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import is_flashinfer_available
+
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+if is_flashinfer_available():
+    from flashinfer import (
+        BatchDecodeWithPagedKVCacheWrapper,
+        BatchPrefillWithPagedKVCacheWrapper,
+        BatchPrefillWithRaggedKVCacheWrapper,
+    )
+    from flashinfer.cascade import merge_state
+    from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+
+class FlashInferAttnBackend(AttentionBackend):
+    """Flashinfer attention kernels."""
+
+    def __init__(self, model_runner: ModelRunner):
+        super().__init__()
+        self.model_runner = model_runner
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            model_runner.model_config.num_attention_heads // model_runner.tp_size,
+            model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
+        ):
+            self.decode_use_tensor_cores = True
+        else:
+            self.decode_use_tensor_cores = False
+
+        self.workspace_buffer = torch.empty(
+            global_config.flashinfer_workspace_size,
+            dtype=torch.uint8,
+            device="cuda",
+        )
+
+        assert not (
+            model_runner.sliding_window_size is not None
+            and model_runner.has_cross_attention
+        ), "Sliding window and cross attention are not supported together"
+
+        self.num_wrappers = 1
+        self.dispatch_reason = None
+        if model_runner.sliding_window_size is not None:
+            self.num_wrappers = 2
+            self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
+        elif model_runner.has_cross_attention:
+            self.num_wrappers = 2
+            self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
+
+        # NOTE: we do not use ragged attention when there are multiple wrappers
+        self.prefill_wrapper_ragged = (
+            BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
+            if self.num_wrappers == 1
+            else None
+        )
+
+        # Two wrappers: one for sliding window attention and one for full attention.
+        # Using two wrappers is unnecessary in the current PR, but are prepared for future PRs
+        self.prefill_wrappers_paged = []
+        self.decode_wrappers = []
+        for _ in range(self.num_wrappers):
+            self.prefill_wrappers_paged.append(
+                BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
+            )
+            self.decode_wrappers.append(
+                BatchDecodeWithPagedKVCacheWrapper(
+                    self.workspace_buffer,
+                    "NHD",
+                    use_tensor_cores=self.decode_use_tensor_cores,
+                )
+            )
+
+        self.forward_metadata = None
+        self.cuda_graph_metadata = {}
+
+    def _get_wrapper_idx(self, layer: nn.Module):
+        if self.num_wrappers == 1:
+            return 0
+
+        if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+            return layer.sliding_window_size == -1
+        if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+            return layer.is_cross_attention
+
+        raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}")
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        if forward_batch.forward_mode.is_decode():
+            prefix_lens = None
+            use_ragged = False
+            extend_no_prefix = False
+            total_num_tokens = None
+        else:
+            prefix_lens = forward_batch.extend_prefix_lens
+
+            # Some heuristics to check whether to use ragged forward
+            use_ragged = False
+            if (
+                torch.sum(forward_batch.seq_lens).item() >= 4096
+                and self.num_wrappers == 1
+            ):
+                use_ragged = True
+
+            total_num_tokens = torch.sum(forward_batch.seq_lens).item()
+            extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()
+
+        update_flashinfer_indices(
+            forward_batch.forward_mode,
+            self.model_runner,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            prefix_lens,
+            use_ragged=use_ragged,
+        )
+
+        self.forward_metadata = (
+            use_ragged,
+            extend_no_prefix,
+            total_num_tokens,
+            self.decode_wrappers,
+        )
+
+    def init_cuda_graph_state(self, max_bs: int):
+        self.cuda_graph_kv_indptr = torch.zeros(
+            (max_bs + 1,), dtype=torch.int32, device="cuda"
+        )
+        self.cuda_graph_kv_indices = torch.zeros(
+            (max_bs * self.model_runner.model_config.context_len,),
+            dtype=torch.int32,
+            device="cuda",
+        )
+        self.cuda_graph_kv_last_page_len = torch.ones(
+            (max_bs,), dtype=torch.int32, device="cuda"
+        )
+
+        # NOTE: the buffers are always in the form of list
+        self.cuda_graph_kv_indptr = [self.cuda_graph_kv_indptr] + [
+            self.cuda_graph_kv_indptr.clone() for _ in range(self.num_wrappers - 1)
+        ]
+        self.cuda_graph_kv_indices = [self.cuda_graph_kv_indices] + [
+            self.cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
+        ]
+
+    def init_forward_metadata_capture_cuda_graph(
+        self, bs: int, req_pool_indices, seq_lens
+    ):
+        decode_wrappers = []
+        for i in range(self.num_wrappers):
+            decode_wrappers.append(
+                BatchDecodeWithPagedKVCacheWrapper(
+                    self.workspace_buffer,
+                    "NHD",
+                    use_cuda_graph=True,
+                    use_tensor_cores=self.decode_use_tensor_cores,
+                    paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1],
+                    paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+                    paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
+                )
+            )
+
+        update_flashinfer_indices(
+            ForwardMode.DECODE,
+            self.model_runner,
+            req_pool_indices,
+            seq_lens,
+            None,
+            decode_wrappers,
+        )
+
+        self.cuda_graph_metadata[bs] = decode_wrappers
+
+        self.forward_metadata = (False, False, None, decode_wrappers)
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, bs: int, req_pool_indices, seq_lens
+    ):
+        update_flashinfer_indices(
+            ForwardMode.DECODE,
+            self.model_runner,
+            req_pool_indices[:bs],
+            seq_lens[:bs],
+            None,
+            self.cuda_graph_metadata[bs],
+        )
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 0
+
+    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+        prefill_wrapper_paged = self.prefill_wrappers_paged[
+            self._get_wrapper_idx(layer)
+        ]
+
+        use_ragged, extend_no_prefix, _, _ = self.forward_metadata
+
+        if not use_ragged:
+            if k is not None:
+                assert v is not None
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer.layer_id, forward_batch.out_cache_loc, k, v
+                )
+            o = prefill_wrapper_paged.forward(
+                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                causal=True,
+                sm_scale=layer.scaling,
+                window_left=layer.sliding_window_size,
+                logits_soft_cap=layer.logit_cap,
+            )
+        else:
+            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim),
+                v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
+                causal=True,
+                sm_scale=layer.scaling,
+                logits_soft_cap=layer.logit_cap,
+            )
+
+            if extend_no_prefix:
+                o = o1
+            else:
+                o2, s2 = prefill_wrapper_paged.forward_return_lse(
+                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                    causal=False,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=layer.logit_cap,
+                )
+
+                o, _ = merge_state(o1, s1, o2, s2)
+
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer.layer_id, forward_batch.out_cache_loc, k, v
+            )
+
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+        decode_wrapper = self.forward_metadata[-1][self._get_wrapper_idx(layer)]
+
+        if k is not None:
+            assert v is not None
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer.layer_id, forward_batch.out_cache_loc, k, v
+            )
+
+        o = decode_wrapper.forward(
+            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+            sm_scale=layer.scaling,
+            logits_soft_cap=layer.logit_cap,
+        )
+
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} (renamed and modified)
@@ -1,8 +1,15 @@
+from enum import Enum, auto
+
 import torch
 import triton
 import triton.language as tl
 
 
+class WrapperDispatch(Enum):
+    SLIDING_WINDOW = auto()
+    CROSS_ATTENTION = auto()
+
+
 @triton.jit
 def create_flashinfer_kv_indices_triton(
     req_to_token_ptr,  # [max_batch, max_context_len]
@@ -47,7 +54,7 @@ class FlashinferUpdater:
         req_pool_indices,
         seq_lens,
         prefix_lens,
-
+        decode_wrappers=None,
         use_ragged=False,
     ):
         self.forward_mode = forward_mode
@@ -66,82 +73,22 @@ class FlashinferUpdater:
         self.head_dim = model_runner.model_config.head_dim
         self.batch_size = len(req_pool_indices)
 
-        self.
-
+        self.decode_wrappers = (
+            decode_wrappers or self.model_runner.attn_backend.decode_wrappers
         )
         self.prefill_wrapper_ragged = (
             self.model_runner.attn_backend.prefill_wrapper_ragged
         )
-        self.
-        self.model_runner.attn_backend.
+        self.prefill_wrappers_paged = (
+            self.model_runner.attn_backend.prefill_wrappers_paged
         )
 
         self.kv_last_page_len = torch.ones(
             (self.batch_size,), dtype=torch.int32, device="cuda"
         )
 
-    def _init_indices_no_sliding_window(self):
-        if self.use_ragged:
-            paged_kernel_lens = self.prefix_lens
-        else:
-            paged_kernel_lens = self.seq_lens
-
-        self.kv_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        self.kv_indices = torch.empty(
-            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
-        )
-
-        create_flashinfer_kv_indices_triton[(self.batch_size,)](
-            self.model_runner.req_to_token_pool.req_to_token,
-            self.req_pool_indices,
-            paged_kernel_lens,
-            self.kv_indptr,
-            None,
-            self.kv_indices,
-            self.model_runner.req_to_token_pool.req_to_token.size(1),
-        )
-
-    def _init_indices_sliding_window(self, wrapper_id):
-        if wrapper_id == 0:
-            # window attention use paged only
-            if self.forward_mode.is_decode():
-                paged_kernel_lens = torch.minimum(
-                    self.seq_lens,
-                    torch.tensor(self.model_runner.sliding_window_size + 1),
-                )
-            else:
-                paged_kernel_lens = torch.minimum(
-                    self.seq_lens,
-                    torch.tensor(self.model_runner.sliding_window_size)
-                    + self.seq_lens
-                    - self.prefix_lens,
-                )
-        else:
-            # full attention
-            paged_kernel_lens = self.seq_lens
-
-        kv_start_idx = self.seq_lens - paged_kernel_lens
-        self.kv_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        self.kv_indices = torch.empty(
-            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
-        )
-        create_flashinfer_kv_indices_triton[(self.batch_size,)](
-            self.model_runner.req_to_token_pool.req_to_token,
-            self.req_pool_indices,
-            paged_kernel_lens,
-            self.kv_indptr,
-            kv_start_idx,
-            self.kv_indices,
-            self.model_runner.req_to_token_pool.req_to_token.size(1),
-        )
-
     def _update_decode_indices(self, decode_wrapper):
+        assert not isinstance(decode_wrapper, list)
         decode_wrapper.end_forward()
         decode_wrapper.begin_forward(
             self.kv_indptr,
@@ -156,6 +103,9 @@ class FlashinferUpdater:
         )
 
     def _update_extend_indices(self, ragged_wrapper, paged_wrapper):
+        assert not isinstance(paged_wrapper, list)
+        assert not isinstance(ragged_wrapper, list)
+
         # extend part
         qo_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
@@ -185,28 +135,75 @@ class FlashinferUpdater:
             1,
         )
 
-    def
-
+    def _get_indices(self, dispatch_reason: WrapperDispatch = None, wrapper_id=0):
+        if dispatch_reason is None:
+            if self.use_ragged:
+                paged_kernel_lens = self.prefix_lens
+            else:
+                paged_kernel_lens = self.seq_lens
+            self.kv_start_idx = None
+        elif dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+            if wrapper_id == 0:
+                # window attention use paged only
+                if self.forward_mode.is_decode():
+                    paged_kernel_lens = torch.minimum(
+                        self.seq_lens,
+                        torch.tensor(self.model_runner.sliding_window_size + 1),
+                    )
+                else:
+                    paged_kernel_lens = torch.minimum(
+                        self.seq_lens,
+                        torch.tensor(self.model_runner.sliding_window_size)
+                        + self.seq_lens
+                        - self.prefix_lens,
+                    )
+            else:
+                # full attention
+                paged_kernel_lens = self.seq_lens
+            self.kv_start_idx = self.seq_lens - paged_kernel_lens
+
+        self.kv_indptr = torch.zeros(
+            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+        )
+        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        self.kv_indices = torch.empty(
+            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
+        )
+
+        create_flashinfer_kv_indices_triton[(self.batch_size,)](
+            self.model_runner.req_to_token_pool.req_to_token,
+            self.req_pool_indices,
+            paged_kernel_lens,
+            self.kv_indptr,
+            self.kv_start_idx,
+            self.kv_indices,
+            self.model_runner.req_to_token_pool.req_to_token.size(1),
+        )
+
+    def _update_indicess_single_wrapper(self):
+        self._get_indices()
 
         if self.forward_mode.is_decode():
-            self._update_decode_indices(self.
+            self._update_decode_indices(self.decode_wrappers[0])
         else:
             self._update_extend_indices(
                 self.prefill_wrapper_ragged,
-                self.
+                self.prefill_wrappers_paged[0],
             )
 
-    def
-
+    def _update_indices_cross_attention(self):
+        pass
 
+    def _update_indices_sliding_window(self):
+        assert self.use_ragged is False
         for wrapper_id in range(2):
-            self.
+            self._get_indices(WrapperDispatch.SLIDING_WINDOW, wrapper_id)
             if self.forward_mode.is_decode():
-                self._update_decode_indices(self.
+                self._update_decode_indices(self.decode_wrappers[wrapper_id])
            else:
                 self._update_extend_indices(
                     None,
-                    self.
+                    self.prefill_wrappers_paged[wrapper_id],
                 )
 
 
@@ -216,7 +213,7 @@ def update_flashinfer_indices(
     req_pool_indices,
     seq_lens,
     prefix_lens,
-
+    decode_wrappers=None,
    use_ragged=False,
 ):
     updater = FlashinferUpdater(
@@ -225,11 +222,16 @@ def update_flashinfer_indices(
         req_pool_indices,
         seq_lens,
         prefix_lens,
-
+        decode_wrappers,
         use_ragged,
     )
 
-
-
+    dispatch_reason = model_runner.attn_backend.dispatch_reason
+
+    if dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+        updater._update_indices_sliding_window()
+    elif dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+        updater._update_indices_cross_attention()
     else:
-
+        assert model_runner.attn_backend.num_wrappers == 1
+        updater._update_indicess_single_wrapper()