sglang 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +23 -1
  3. sglang/bench_latency.py +46 -25
  4. sglang/bench_serving.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +14 -1
  6. sglang/lang/interpreter.py +16 -6
  7. sglang/lang/ir.py +20 -4
  8. sglang/srt/configs/model_config.py +11 -9
  9. sglang/srt/constrained/fsm_cache.py +9 -1
  10. sglang/srt/constrained/jump_forward.py +15 -2
  11. sglang/srt/layers/activation.py +4 -4
  12. sglang/srt/layers/attention/__init__.py +49 -0
  13. sglang/srt/layers/attention/flashinfer_backend.py +277 -0
  14. sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
  15. sglang/srt/layers/attention/triton_backend.py +161 -0
  16. sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
  17. sglang/srt/layers/layernorm.py +4 -4
  18. sglang/srt/layers/logits_processor.py +19 -15
  19. sglang/srt/layers/pooler.py +3 -3
  20. sglang/srt/layers/quantization/__init__.py +0 -2
  21. sglang/srt/layers/radix_attention.py +6 -4
  22. sglang/srt/layers/sampler.py +6 -4
  23. sglang/srt/layers/torchao_utils.py +18 -0
  24. sglang/srt/lora/lora.py +20 -21
  25. sglang/srt/lora/lora_manager.py +97 -25
  26. sglang/srt/managers/detokenizer_manager.py +31 -18
  27. sglang/srt/managers/image_processor.py +187 -0
  28. sglang/srt/managers/io_struct.py +99 -75
  29. sglang/srt/managers/schedule_batch.py +184 -63
  30. sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
  31. sglang/srt/managers/scheduler.py +1021 -0
  32. sglang/srt/managers/tokenizer_manager.py +120 -248
  33. sglang/srt/managers/tp_worker.py +28 -925
  34. sglang/srt/mem_cache/memory_pool.py +34 -52
  35. sglang/srt/model_executor/cuda_graph_runner.py +15 -19
  36. sglang/srt/model_executor/forward_batch_info.py +94 -95
  37. sglang/srt/model_executor/model_runner.py +76 -75
  38. sglang/srt/models/baichuan.py +10 -10
  39. sglang/srt/models/chatglm.py +12 -12
  40. sglang/srt/models/commandr.py +10 -10
  41. sglang/srt/models/dbrx.py +12 -12
  42. sglang/srt/models/deepseek.py +10 -10
  43. sglang/srt/models/deepseek_v2.py +14 -15
  44. sglang/srt/models/exaone.py +10 -10
  45. sglang/srt/models/gemma.py +10 -10
  46. sglang/srt/models/gemma2.py +11 -11
  47. sglang/srt/models/gpt_bigcode.py +10 -10
  48. sglang/srt/models/grok.py +10 -10
  49. sglang/srt/models/internlm2.py +10 -10
  50. sglang/srt/models/llama.py +14 -10
  51. sglang/srt/models/llama_classification.py +5 -5
  52. sglang/srt/models/llama_embedding.py +4 -4
  53. sglang/srt/models/llama_reward.py +142 -0
  54. sglang/srt/models/llava.py +39 -33
  55. sglang/srt/models/llavavid.py +31 -28
  56. sglang/srt/models/minicpm.py +10 -10
  57. sglang/srt/models/minicpm3.py +14 -15
  58. sglang/srt/models/mixtral.py +10 -10
  59. sglang/srt/models/mixtral_quant.py +10 -10
  60. sglang/srt/models/olmoe.py +10 -10
  61. sglang/srt/models/qwen.py +10 -10
  62. sglang/srt/models/qwen2.py +11 -11
  63. sglang/srt/models/qwen2_moe.py +10 -10
  64. sglang/srt/models/stablelm.py +10 -10
  65. sglang/srt/models/torch_native_llama.py +506 -0
  66. sglang/srt/models/xverse.py +10 -10
  67. sglang/srt/models/xverse_moe.py +10 -10
  68. sglang/srt/sampling/sampling_batch_info.py +36 -27
  69. sglang/srt/sampling/sampling_params.py +3 -1
  70. sglang/srt/server.py +170 -119
  71. sglang/srt/server_args.py +54 -27
  72. sglang/srt/utils.py +101 -128
  73. sglang/test/runners.py +71 -26
  74. sglang/test/test_programs.py +38 -5
  75. sglang/test/test_utils.py +18 -9
  76. sglang/version.py +1 -1
  77. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/METADATA +37 -19
  78. sglang-0.3.3.dist-info/RECORD +139 -0
  79. sglang/srt/layers/attention_backend.py +0 -474
  80. sglang/srt/managers/controller_multi.py +0 -207
  81. sglang/srt/managers/controller_single.py +0 -164
  82. sglang-0.3.2.dist-info/RECORD +0 -135
  83. /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
  84. /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
  85. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
  86. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
  87. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,49 @@
+ from abc import ABC, abstractmethod
+
+ from torch import nn
+
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+
+ class AttentionBackend(ABC):
+     """The base class of attention backends"""
+
+     @abstractmethod
+     def init_forward_metadata(self, forward_batch: ForwardBatch):
+         """Init the metadata for a forward pass."""
+         raise NotImplementedError()
+
+     def init_cuda_graph_state(self, max_bs: int):
+         """Init the global shared states for cuda graph."""
+         raise NotImplementedError()
+
+     def init_forward_metadata_capture_cuda_graph(
+         self, bs: int, req_pool_indices, seq_lens
+     ):
+         """Init the metadata for a forward pass for capturing a cuda graph."""
+         raise NotImplementedError()
+
+     def init_forward_metadata_replay_cuda_graph(
+         self, bs: int, req_pool_indices, seq_lens
+     ):
+         """Init the metadata for a forward pass for replaying a cuda graph."""
+         raise NotImplementedError()
+
+     def get_cuda_graph_seq_len_fill_value(self):
+         """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
+         raise NotImplementedError()
+
+     def forward(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+         """Run forward on an attention layer."""
+         if forward_batch.forward_mode.is_decode():
+             return self.forward_decode(q, k, v, layer, forward_batch)
+         else:
+             return self.forward_extend(q, k, v, layer, forward_batch)
+
+     def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+         """Run a forward for decode."""
+         raise NotImplementedError()
+
+     def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+         """Run a forward for extend."""
+         raise NotImplementedError()
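
The new sglang/srt/layers/attention/__init__.py above defines the interface that the concrete backends in this release (flashinfer_backend.py and triton_backend.py) implement. As a rough, hypothetical sketch that is not taken from this diff, a third backend would subclass AttentionBackend and fill in these hooks; every name below other than the two imported classes is invented for illustration.

# Hypothetical sketch only: a minimal custom backend against the interface above.
from torch import nn

from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class MyCustomAttnBackend(AttentionBackend):
    def init_forward_metadata(self, forward_batch: ForwardBatch):
        # Precompute whatever indices or kernel plans the forward calls need.
        self.forward_metadata = None

    def get_cuda_graph_seq_len_fill_value(self):
        # Padded sequence lengths are filled with this value during cuda graph capture.
        return 0

    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
        # Call an extend/prefill kernel here and return the attention output.
        raise NotImplementedError("plug in an extend kernel")

    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
        # Call a single-token decode kernel here.
        raise NotImplementedError("plug in a decode kernel")
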
@@ -0,0 +1,277 @@
+ from __future__ import annotations
+
+ """
+ Support different attention backends.
+ Now there are two backends: FlashInfer and Triton.
+ FlashInfer is faster and Triton is easier to customize.
+ Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
+ """
+
+ from typing import TYPE_CHECKING
+
+ import torch
+ import torch.nn as nn
+
+ from sglang.global_config import global_config
+ from sglang.srt.layers.attention import AttentionBackend
+ from sglang.srt.layers.attention.flashinfer_utils import (
+     WrapperDispatch,
+     update_flashinfer_indices,
+ )
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+ from sglang.srt.utils import is_flashinfer_available
+
+ if TYPE_CHECKING:
+     from sglang.srt.model_executor.model_runner import ModelRunner
+
+ if is_flashinfer_available():
+     from flashinfer import (
+         BatchDecodeWithPagedKVCacheWrapper,
+         BatchPrefillWithPagedKVCacheWrapper,
+         BatchPrefillWithRaggedKVCacheWrapper,
+     )
+     from flashinfer.cascade import merge_state
+     from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+
+ class FlashInferAttnBackend(AttentionBackend):
+     """Flashinfer attention kernels."""
+
+     def __init__(self, model_runner: ModelRunner):
+         super().__init__()
+         self.model_runner = model_runner
+
+         if not _grouped_size_compiled_for_decode_kernels(
+             model_runner.model_config.num_attention_heads // model_runner.tp_size,
+             model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
+         ):
+             self.decode_use_tensor_cores = True
+         else:
+             self.decode_use_tensor_cores = False
+
+         self.workspace_buffer = torch.empty(
+             global_config.flashinfer_workspace_size,
+             dtype=torch.uint8,
+             device="cuda",
+         )
+
+         assert not (
+             model_runner.sliding_window_size is not None
+             and model_runner.has_cross_attention
+         ), "Sliding window and cross attention are not supported together"
+
+         self.num_wrappers = 1
+         self.dispatch_reason = None
+         if model_runner.sliding_window_size is not None:
+             self.num_wrappers = 2
+             self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
+         elif model_runner.has_cross_attention:
+             self.num_wrappers = 2
+             self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
+
+         # NOTE: we do not use ragged attention when there are multiple wrappers
+         self.prefill_wrapper_ragged = (
+             BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
+             if self.num_wrappers == 1
+             else None
+         )
+
+         # Two wrappers: one for sliding window attention and one for full attention.
+         # Using two wrappers is unnecessary in the current PR, but they are prepared for future PRs
+         self.prefill_wrappers_paged = []
+         self.decode_wrappers = []
+         for _ in range(self.num_wrappers):
+             self.prefill_wrappers_paged.append(
+                 BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
+             )
+             self.decode_wrappers.append(
+                 BatchDecodeWithPagedKVCacheWrapper(
+                     self.workspace_buffer,
+                     "NHD",
+                     use_tensor_cores=self.decode_use_tensor_cores,
+                 )
+             )
+
+         self.forward_metadata = None
+         self.cuda_graph_metadata = {}
+
+     def _get_wrapper_idx(self, layer: nn.Module):
+         if self.num_wrappers == 1:
+             return 0
+
+         if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+             return layer.sliding_window_size == -1
+         if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+             return layer.is_cross_attention
+
+         raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}")
+
+     def init_forward_metadata(self, forward_batch: ForwardBatch):
+         if forward_batch.forward_mode.is_decode():
+             prefix_lens = None
+             use_ragged = False
+             extend_no_prefix = False
+             total_num_tokens = None
+         else:
+             prefix_lens = forward_batch.extend_prefix_lens
+
+             # Some heuristics to check whether to use ragged forward
+             use_ragged = False
+             if (
+                 torch.sum(forward_batch.seq_lens).item() >= 4096
+                 and self.num_wrappers == 1
+             ):
+                 use_ragged = True
+
+             total_num_tokens = torch.sum(forward_batch.seq_lens).item()
+             extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()
+
+         update_flashinfer_indices(
+             forward_batch.forward_mode,
+             self.model_runner,
+             forward_batch.req_pool_indices,
+             forward_batch.seq_lens,
+             prefix_lens,
+             use_ragged=use_ragged,
+         )
+
+         self.forward_metadata = (
+             use_ragged,
+             extend_no_prefix,
+             total_num_tokens,
+             self.decode_wrappers,
+         )
+
+     def init_cuda_graph_state(self, max_bs: int):
+         self.cuda_graph_kv_indptr = torch.zeros(
+             (max_bs + 1,), dtype=torch.int32, device="cuda"
+         )
+         self.cuda_graph_kv_indices = torch.zeros(
+             (max_bs * self.model_runner.model_config.context_len,),
+             dtype=torch.int32,
+             device="cuda",
+         )
+         self.cuda_graph_kv_last_page_len = torch.ones(
+             (max_bs,), dtype=torch.int32, device="cuda"
+         )
+
+         # NOTE: the buffers are always in the form of list
+         self.cuda_graph_kv_indptr = [self.cuda_graph_kv_indptr] + [
+             self.cuda_graph_kv_indptr.clone() for _ in range(self.num_wrappers - 1)
+         ]
+         self.cuda_graph_kv_indices = [self.cuda_graph_kv_indices] + [
+             self.cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
+         ]
+
+     def init_forward_metadata_capture_cuda_graph(
+         self, bs: int, req_pool_indices, seq_lens
+     ):
+         decode_wrappers = []
+         for i in range(self.num_wrappers):
+             decode_wrappers.append(
+                 BatchDecodeWithPagedKVCacheWrapper(
+                     self.workspace_buffer,
+                     "NHD",
+                     use_cuda_graph=True,
+                     use_tensor_cores=self.decode_use_tensor_cores,
+                     paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1],
+                     paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+                     paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
+                 )
+             )
+
+         update_flashinfer_indices(
+             ForwardMode.DECODE,
+             self.model_runner,
+             req_pool_indices,
+             seq_lens,
+             None,
+             decode_wrappers,
+         )
+
+         self.cuda_graph_metadata[bs] = decode_wrappers
+
+         self.forward_metadata = (False, False, None, decode_wrappers)
+
+     def init_forward_metadata_replay_cuda_graph(
+         self, bs: int, req_pool_indices, seq_lens
+     ):
+         update_flashinfer_indices(
+             ForwardMode.DECODE,
+             self.model_runner,
+             req_pool_indices[:bs],
+             seq_lens[:bs],
+             None,
+             self.cuda_graph_metadata[bs],
+         )
+
+     def get_cuda_graph_seq_len_fill_value(self):
+         return 0
+
+     def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+         prefill_wrapper_paged = self.prefill_wrappers_paged[
+             self._get_wrapper_idx(layer)
+         ]
+
+         use_ragged, extend_no_prefix, _, _ = self.forward_metadata
+
+         if not use_ragged:
+             if k is not None:
+                 assert v is not None
+                 forward_batch.token_to_kv_pool.set_kv_buffer(
+                     layer.layer_id, forward_batch.out_cache_loc, k, v
+                 )
+             o = prefill_wrapper_paged.forward(
+                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                 causal=True,
+                 sm_scale=layer.scaling,
+                 window_left=layer.sliding_window_size,
+                 logits_soft_cap=layer.logit_cap,
+             )
+         else:
+             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                 k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim),
+                 v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
+                 causal=True,
+                 sm_scale=layer.scaling,
+                 logits_soft_cap=layer.logit_cap,
+             )
+
+             if extend_no_prefix:
+                 o = o1
+             else:
+                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
+                     q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                     forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                     causal=False,
+                     sm_scale=layer.scaling,
+                     logits_soft_cap=layer.logit_cap,
+                 )
+
+                 o, _ = merge_state(o1, s1, o2, s2)
+
+             forward_batch.token_to_kv_pool.set_kv_buffer(
+                 layer.layer_id, forward_batch.out_cache_loc, k, v
+             )
+
+         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+     def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+         decode_wrapper = self.forward_metadata[-1][self._get_wrapper_idx(layer)]
+
+         if k is not None:
+             assert v is not None
+             forward_batch.token_to_kv_pool.set_kv_buffer(
+                 layer.layer_id, forward_batch.out_cache_loc, k, v
+             )
+
+         o = decode_wrapper.forward(
+             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+             sm_scale=layer.scaling,
+             logits_soft_cap=layer.logit_cap,
+         )
+
+         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
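
For orientation: this backend is driven through the AttentionBackend interface from the previous file. The model runner owns one backend instance, builds the FlashInfer plans once per batch via init_forward_metadata, and then calls forward per attention layer, which dispatches to forward_decode or forward_extend depending on the forward mode. The sketch below is schematic only; the helper names (run_attention_layers, layers, q_kv_per_layer) do not appear in the diff.

# Schematic sketch only; not code from this diff.
def run_attention_layers(backend, layers, q_kv_per_layer, forward_batch):
    # Once per batch: build the FlashInfer indptr/indices and wrapper plans.
    backend.init_forward_metadata(forward_batch)

    outputs = []
    for layer, (q, k, v) in zip(layers, q_kv_per_layer):
        # AttentionBackend.forward() routes to forward_decode() or
        # forward_extend() based on forward_batch.forward_mode.
        outputs.append(backend.forward(q, k, v, layer, forward_batch))
    return outputs
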
@@ -1,8 +1,15 @@
+ from enum import Enum, auto
+
  import torch
  import triton
  import triton.language as tl


+ class WrapperDispatch(Enum):
+     SLIDING_WINDOW = auto()
+     CROSS_ATTENTION = auto()
+
+
  @triton.jit
  def create_flashinfer_kv_indices_triton(
      req_to_token_ptr,  # [max_batch, max_context_len]
@@ -47,7 +54,7 @@ class FlashinferUpdater:
          req_pool_indices,
          seq_lens,
          prefix_lens,
-         decode_wrapper=None,
+         decode_wrappers=None,
          use_ragged=False,
      ):
          self.forward_mode = forward_mode
@@ -66,82 +73,22 @@ class FlashinferUpdater:
          self.head_dim = model_runner.model_config.head_dim
          self.batch_size = len(req_pool_indices)

-         self.decode_wrapper = (
-             decode_wrapper or self.model_runner.attn_backend.decode_wrapper
+         self.decode_wrappers = (
+             decode_wrappers or self.model_runner.attn_backend.decode_wrappers
          )
          self.prefill_wrapper_ragged = (
              self.model_runner.attn_backend.prefill_wrapper_ragged
          )
-         self.prefill_wrapper_paged = (
-             self.model_runner.attn_backend.prefill_wrapper_paged
+         self.prefill_wrappers_paged = (
+             self.model_runner.attn_backend.prefill_wrappers_paged
          )

          self.kv_last_page_len = torch.ones(
              (self.batch_size,), dtype=torch.int32, device="cuda"
          )

-     def _init_indices_no_sliding_window(self):
-         if self.use_ragged:
-             paged_kernel_lens = self.prefix_lens
-         else:
-             paged_kernel_lens = self.seq_lens
-
-         self.kv_indptr = torch.zeros(
-             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-         )
-         self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-         self.kv_indices = torch.empty(
-             self.kv_indptr[-1], dtype=torch.int32, device="cuda"
-         )
-
-         create_flashinfer_kv_indices_triton[(self.batch_size,)](
-             self.model_runner.req_to_token_pool.req_to_token,
-             self.req_pool_indices,
-             paged_kernel_lens,
-             self.kv_indptr,
-             None,
-             self.kv_indices,
-             self.model_runner.req_to_token_pool.req_to_token.size(1),
-         )
-
-     def _init_indices_sliding_window(self, wrapper_id):
-         if wrapper_id == 0:
-             # window attention use paged only
-             if self.forward_mode.is_decode():
-                 paged_kernel_lens = torch.minimum(
-                     self.seq_lens,
-                     torch.tensor(self.model_runner.sliding_window_size + 1),
-                 )
-             else:
-                 paged_kernel_lens = torch.minimum(
-                     self.seq_lens,
-                     torch.tensor(self.model_runner.sliding_window_size)
-                     + self.seq_lens
-                     - self.prefix_lens,
-                 )
-         else:
-             # full attention
-             paged_kernel_lens = self.seq_lens
-
-         kv_start_idx = self.seq_lens - paged_kernel_lens
-         self.kv_indptr = torch.zeros(
-             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-         )
-         self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-         self.kv_indices = torch.empty(
-             self.kv_indptr[-1], dtype=torch.int32, device="cuda"
-         )
-         create_flashinfer_kv_indices_triton[(self.batch_size,)](
-             self.model_runner.req_to_token_pool.req_to_token,
-             self.req_pool_indices,
-             paged_kernel_lens,
-             self.kv_indptr,
-             kv_start_idx,
-             self.kv_indices,
-             self.model_runner.req_to_token_pool.req_to_token.size(1),
-         )
-
      def _update_decode_indices(self, decode_wrapper):
+         assert not isinstance(decode_wrapper, list)
          decode_wrapper.end_forward()
          decode_wrapper.begin_forward(
              self.kv_indptr,
@@ -156,6 +103,9 @@ class FlashinferUpdater:
          )

      def _update_extend_indices(self, ragged_wrapper, paged_wrapper):
+         assert not isinstance(paged_wrapper, list)
+         assert not isinstance(ragged_wrapper, list)
+
          # extend part
          qo_indptr = torch.zeros(
              (self.batch_size + 1,), dtype=torch.int32, device="cuda"
@@ -185,28 +135,75 @@ class FlashinferUpdater:
              1,
          )

-     def update_indices_no_sliding_window(self):
-         self._init_indices_no_sliding_window()
+     def _get_indices(self, dispatch_reason: WrapperDispatch = None, wrapper_id=0):
+         if dispatch_reason is None:
+             if self.use_ragged:
+                 paged_kernel_lens = self.prefix_lens
+             else:
+                 paged_kernel_lens = self.seq_lens
+             self.kv_start_idx = None
+         elif dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+             if wrapper_id == 0:
+                 # window attention use paged only
+                 if self.forward_mode.is_decode():
+                     paged_kernel_lens = torch.minimum(
+                         self.seq_lens,
+                         torch.tensor(self.model_runner.sliding_window_size + 1),
+                     )
+                 else:
+                     paged_kernel_lens = torch.minimum(
+                         self.seq_lens,
+                         torch.tensor(self.model_runner.sliding_window_size)
+                         + self.seq_lens
+                         - self.prefix_lens,
+                     )
+             else:
+                 # full attention
+                 paged_kernel_lens = self.seq_lens
+             self.kv_start_idx = self.seq_lens - paged_kernel_lens
+
+         self.kv_indptr = torch.zeros(
+             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+         )
+         self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+         self.kv_indices = torch.empty(
+             self.kv_indptr[-1], dtype=torch.int32, device="cuda"
+         )
+
+         create_flashinfer_kv_indices_triton[(self.batch_size,)](
+             self.model_runner.req_to_token_pool.req_to_token,
+             self.req_pool_indices,
+             paged_kernel_lens,
+             self.kv_indptr,
+             self.kv_start_idx,
+             self.kv_indices,
+             self.model_runner.req_to_token_pool.req_to_token.size(1),
+         )
+
+     def _update_indicess_single_wrapper(self):
+         self._get_indices()

          if self.forward_mode.is_decode():
-             self._update_decode_indices(self.decode_wrapper)
+             self._update_decode_indices(self.decode_wrappers[0])
          else:
              self._update_extend_indices(
                  self.prefill_wrapper_ragged,
-                 self.prefill_wrapper_paged,
+                 self.prefill_wrappers_paged[0],
              )

-     def update_indices_sliding_window(self):
-         assert self.use_ragged is False
+     def _update_indices_cross_attention(self):
+         pass

+     def _update_indices_sliding_window(self):
+         assert self.use_ragged is False
          for wrapper_id in range(2):
-             self._init_indices_sliding_window(wrapper_id)
+             self._get_indices(WrapperDispatch.SLIDING_WINDOW, wrapper_id)
              if self.forward_mode.is_decode():
-                 self._update_decode_indices(self.decode_wrapper[wrapper_id])
+                 self._update_decode_indices(self.decode_wrappers[wrapper_id])
              else:
                  self._update_extend_indices(
                      None,
-                     self.prefill_wrapper_paged[wrapper_id],
+                     self.prefill_wrappers_paged[wrapper_id],
                  )


@@ -216,7 +213,7 @@ def update_flashinfer_indices(
      req_pool_indices,
      seq_lens,
      prefix_lens,
-     decode_wrapper=None,
+     decode_wrappers=None,
      use_ragged=False,
  ):
      updater = FlashinferUpdater(
@@ -225,11 +222,16 @@ def update_flashinfer_indices(
          req_pool_indices,
          seq_lens,
          prefix_lens,
-         decode_wrapper,
+         decode_wrappers,
          use_ragged,
      )

-     if model_runner.sliding_window_size is None:
-         updater.update_indices_no_sliding_window()
+     dispatch_reason = model_runner.attn_backend.dispatch_reason
+
+     if dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+         updater._update_indices_sliding_window()
+     elif dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+         updater._update_indices_cross_attention()
      else:
-         updater.update_indices_sliding_window()
+         assert model_runner.attn_backend.num_wrappers == 1
+         updater._update_indicess_single_wrapper()
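
The reworked FlashinferUpdater builds a CSR-style layout for the paged KV cache: kv_indptr is the cumulative sum of the per-request KV lengths, and kv_indices lists each request's KV slot numbers taken from req_to_token (on GPU this gather is performed by create_flashinfer_kv_indices_triton). The following CPU-only sketch with made-up toy values, not taken from the diff, illustrates that layout.

# Toy illustration of the kv_indptr / kv_indices layout (assumed values).
import torch

# Hypothetical req_to_token mapping: 3 requests, up to 4 KV slots each.
req_to_token = torch.tensor(
    [[10, 11, 12, 0], [20, 0, 0, 0], [30, 31, 0, 0]], dtype=torch.int32
)
paged_kernel_lens = torch.tensor([3, 1, 2], dtype=torch.int32)  # KV tokens per request

kv_indptr = torch.zeros(len(paged_kernel_lens) + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)  # -> tensor([0, 3, 4, 6])

# CPU equivalent of the Triton gather, shown for clarity.
lens = paged_kernel_lens.tolist()
kv_indices = torch.cat([req_to_token[i, :n] for i, n in enumerate(lens)])
# kv_indices -> tensor([10, 11, 12, 20, 30, 31])
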