sglang 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. sglang/api.py +7 -1
  2. sglang/bench_latency.py +9 -6
  3. sglang/bench_serving.py +46 -22
  4. sglang/global_config.py +1 -1
  5. sglang/lang/backend/runtime_endpoint.py +60 -49
  6. sglang/lang/compiler.py +2 -2
  7. sglang/lang/interpreter.py +4 -2
  8. sglang/lang/ir.py +16 -7
  9. sglang/srt/constrained/base_tool_cache.py +1 -1
  10. sglang/srt/constrained/fsm_cache.py +12 -2
  11. sglang/srt/constrained/jump_forward.py +13 -2
  12. sglang/srt/layers/activation.py +32 -0
  13. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  14. sglang/srt/layers/extend_attention.py +9 -2
  15. sglang/srt/layers/fused_moe/__init__.py +1 -0
  16. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  17. sglang/srt/layers/fused_moe/layer.py +587 -0
  18. sglang/srt/layers/layernorm.py +65 -0
  19. sglang/srt/layers/logits_processor.py +7 -2
  20. sglang/srt/layers/pooler.py +50 -0
  21. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  22. sglang/srt/layers/radix_attention.py +40 -16
  23. sglang/srt/managers/detokenizer_manager.py +31 -9
  24. sglang/srt/managers/io_struct.py +63 -0
  25. sglang/srt/managers/policy_scheduler.py +173 -25
  26. sglang/srt/managers/schedule_batch.py +115 -97
  27. sglang/srt/managers/tokenizer_manager.py +194 -112
  28. sglang/srt/managers/tp_worker.py +290 -359
  29. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  30. sglang/srt/mem_cache/chunk_cache.py +43 -20
  31. sglang/srt/mem_cache/memory_pool.py +2 -2
  32. sglang/srt/mem_cache/radix_cache.py +74 -40
  33. sglang/srt/model_executor/cuda_graph_runner.py +71 -25
  34. sglang/srt/model_executor/forward_batch_info.py +293 -156
  35. sglang/srt/model_executor/model_runner.py +77 -57
  36. sglang/srt/models/chatglm.py +2 -2
  37. sglang/srt/models/commandr.py +1 -1
  38. sglang/srt/models/deepseek.py +2 -2
  39. sglang/srt/models/deepseek_v2.py +7 -6
  40. sglang/srt/models/gemma.py +1 -1
  41. sglang/srt/models/gemma2.py +11 -6
  42. sglang/srt/models/grok.py +50 -396
  43. sglang/srt/models/internlm2.py +2 -7
  44. sglang/srt/models/llama2.py +4 -4
  45. sglang/srt/models/llama_embedding.py +88 -0
  46. sglang/srt/models/minicpm.py +2 -2
  47. sglang/srt/models/mixtral.py +56 -254
  48. sglang/srt/models/mixtral_quant.py +1 -4
  49. sglang/srt/models/qwen.py +2 -2
  50. sglang/srt/models/qwen2.py +2 -2
  51. sglang/srt/models/qwen2_moe.py +2 -13
  52. sglang/srt/models/stablelm.py +1 -1
  53. sglang/srt/openai_api/adapter.py +187 -48
  54. sglang/srt/openai_api/protocol.py +37 -1
  55. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  56. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  57. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  58. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  59. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  60. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  61. sglang/srt/sampling_params.py +31 -8
  62. sglang/srt/server.py +91 -29
  63. sglang/srt/server_args.py +32 -19
  64. sglang/srt/utils.py +32 -15
  65. sglang/test/run_eval.py +10 -1
  66. sglang/test/runners.py +81 -73
  67. sglang/test/simple_eval_humaneval.py +2 -8
  68. sglang/test/simple_eval_mgsm.py +203 -0
  69. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  70. sglang/test/test_layernorm.py +60 -0
  71. sglang/test/test_programs.py +36 -7
  72. sglang/test/test_utils.py +24 -2
  73. sglang/utils.py +0 -1
  74. sglang/version.py +1 -1
  75. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/METADATA +33 -16
  76. sglang-0.2.13.dist-info/RECORD +112 -0
  77. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
  78. sglang/srt/layers/linear.py +0 -884
  79. sglang/srt/layers/quantization/__init__.py +0 -64
  80. sglang/srt/layers/quantization/fp8.py +0 -677
  81. sglang/srt/model_loader/model_loader.py +0 -292
  82. sglang/srt/model_loader/utils.py +0 -275
  83. sglang-0.2.11.dist-info/RECORD +0 -102
  84. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
  85. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py
@@ -16,13 +16,17 @@ limitations under the License.
 """ModelRunner runs the forward passes of the models."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import List
+from typing import TYPE_CHECKING, List, Optional
 
 import numpy as np
 import torch
 
+from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
 
 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
@@ -39,25 +43,33 @@ class InputMetadata:
 
     forward_mode: ForwardMode
     batch_size: int
-    total_num_tokens: int
     req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor
-    positions: torch.Tensor
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: BaseTokenToKVPool
 
-    # For extend
-    extend_seq_lens: torch.Tensor
-    extend_start_loc: torch.Tensor
-    extend_no_prefix: bool
-
     # Output location of the KV cache
-    out_cache_loc: torch.Tensor = None
+    out_cache_loc: torch.Tensor
+
+    total_num_tokens: int = None
+
+    # Position information
+    positions: torch.Tensor = None
+
+    # For extend
+    extend_seq_lens: torch.Tensor = None
+    extend_start_loc: torch.Tensor = None
+    extend_no_prefix: bool = None
 
     # Output options
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
+    # For multimodal
+    pixel_values: List[torch.Tensor] = None
+    image_sizes: List[List[int]] = None
+    image_offsets: List[int] = None
+
     # Trition attention backend
     triton_max_seq_len: int = 0
     triton_max_extend_len: int = 0
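
The field reshuffle above groups the required fields (no default) ahead of the lazily computed ones, which now default to None, as Python dataclasses require: a field without a default may not follow a field with one. A minimal sketch of that pattern, with illustrative names that are not taken from the diff:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Example:
    # Required field: no default, so it must precede every defaulted field.
    out_loc: torch.Tensor
    # Lazily filled field: defaults to None and is populated after construction.
    positions: Optional[torch.Tensor] = None


meta = Example(out_loc=torch.zeros(4, dtype=torch.int64))
meta.positions = torch.arange(4)  # filled in later, like the compute_* methods below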
@@ -70,107 +82,175 @@ class InputMetadata:
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
     flashinfer_use_ragged: bool = False
 
-    @classmethod
-    def create(
-        cls,
-        model_runner,
-        forward_mode,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        top_logprobs_nums=None,
-        return_logprob=False,
-        skip_flashinfer_init=False,
-    ):
-        flashinfer_use_ragged = False
-        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
-            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
-                flashinfer_use_ragged = True
-            init_flashinfer_args(
-                forward_mode,
-                model_runner,
-                req_pool_indices,
-                seq_lens,
-                prefix_lens,
-                model_runner.flashinfer_decode_wrapper,
-                flashinfer_use_ragged,
+    def init_multimuldal_info(self, batch: ScheduleBatch):
+        reqs = batch.reqs
+        self.pixel_values = [r.pixel_values for r in reqs]
+        self.image_sizes = [r.image_size for r in reqs]
+        self.image_offsets = [
+            (
+                (r.image_offset - len(r.prefix_indices))
+                if r.image_offset is not None
+                else 0
             )
+            for r in reqs
+        ]
 
-        batch_size = len(req_pool_indices)
+    def compute_positions(self, batch: ScheduleBatch):
+        position_ids_offsets = batch.position_ids_offsets
 
-        if forward_mode == ForwardMode.DECODE:
-            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-            extend_seq_lens = extend_start_loc = extend_no_prefix = None
-            if not model_runner.server_args.disable_flashinfer:
-                # This variable is not needed in this case,
-                # we do not compute it to make it compatbile with cuda graph.
-                total_num_tokens = None
+        if self.forward_mode == ForwardMode.DECODE:
+            if True:
+                self.positions = self.seq_lens - 1
             else:
-                total_num_tokens = int(torch.sum(seq_lens))
+                # Deprecated
+                self.positions = (self.seq_lens - 1) + position_ids_offsets
         else:
-            seq_lens_cpu = seq_lens.cpu().numpy()
-            prefix_lens_cpu = prefix_lens.cpu().numpy()
-            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
-            positions = torch.tensor(
-                np.concatenate(
-                    [
-                        np.arange(
-                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
-                        )
-                        for i in range(batch_size)
-                    ],
-                    axis=0,
-                ),
-                device="cuda",
-            )
-            extend_seq_lens = seq_lens - prefix_lens
-            extend_start_loc = torch.zeros_like(seq_lens)
-            extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
-            extend_no_prefix = torch.all(prefix_lens == 0)
-            total_num_tokens = int(torch.sum(seq_lens))
+            if True:
+                self.positions = torch.tensor(
+                    np.concatenate(
+                        [
+                            np.arange(len(req.prefix_indices), len(req.fill_ids))
+                            for req in batch.reqs
+                        ],
+                        axis=0,
+                    ),
+                    device="cuda",
+                )
+            else:
+                # Deprecated
+                position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
+                self.positions = torch.tensor(
+                    np.concatenate(
+                        [
+                            np.arange(
+                                len(req.prefix_indices) + position_ids_offsets_cpu[i],
+                                len(req.fill_ids) + position_ids_offsets_cpu[i],
+                            )
+                            for i, req in enumerate(batch.reqs)
+                        ],
+                        axis=0,
+                    ),
+                    device="cuda",
+                )
+
+        # Positions should be in long type
+        self.positions = self.positions.to(torch.int64)
+
+    def compute_extend_infos(self, batch: ScheduleBatch):
+        if self.forward_mode == ForwardMode.DECODE:
+            self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
+        else:
+            extend_lens_cpu = [
+                len(r.fill_ids) - len(r.prefix_indices) for r in batch.reqs
+            ]
+            self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
+            self.extend_start_loc = torch.zeros_like(self.seq_lens)
+            self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
+            self.extend_no_prefix = all(len(r.prefix_indices) == 0 for r in batch.reqs)
 
+    @classmethod
+    def from_schedule_batch(
+        cls,
+        model_runner: "ModelRunner",
+        batch: ScheduleBatch,
+        forward_mode: ForwardMode,
+    ):
         ret = cls(
             forward_mode=forward_mode,
-            batch_size=batch_size,
-            total_num_tokens=total_num_tokens,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            positions=positions,
+            batch_size=batch.batch_size(),
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
             req_to_token_pool=model_runner.req_to_token_pool,
             token_to_kv_pool=model_runner.token_to_kv_pool,
-            out_cache_loc=out_cache_loc,
-            extend_seq_lens=extend_seq_lens,
-            extend_start_loc=extend_start_loc,
-            extend_no_prefix=extend_no_prefix,
-            return_logprob=return_logprob,
-            top_logprobs_nums=top_logprobs_nums,
-            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
-            flashinfer_use_ragged=flashinfer_use_ragged,
+            out_cache_loc=batch.out_cache_loc,
+            return_logprob=batch.return_logprob,
+            top_logprobs_nums=batch.top_logprobs_nums,
         )
 
+        ret.compute_positions(batch)
+
+        ret.compute_extend_infos(batch)
+
+        if (
+            forward_mode != ForwardMode.DECODE
+            or model_runner.server_args.disable_flashinfer
+        ):
+            ret.total_num_tokens = int(torch.sum(ret.seq_lens))
+
+        if forward_mode != ForwardMode.DECODE:
+            ret.init_multimuldal_info(batch)
+
+        prefix_lens = None
+        if forward_mode != ForwardMode.DECODE:
+            prefix_lens = torch.tensor(
+                [len(r.prefix_indices) for r in batch.reqs], device="cuda"
+            )
+
         if model_runner.server_args.disable_flashinfer:
-            (
-                ret.triton_max_seq_len,
-                ret.triton_max_extend_len,
-                ret.triton_start_loc,
-                ret.triton_prefix_lens,
-            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)
+            ret.init_triton_args(batch, prefix_lens)
+
+        flashinfer_use_ragged = False
+        if not model_runner.server_args.disable_flashinfer:
+            if (
+                forward_mode != ForwardMode.DECODE
+                and int(torch.sum(ret.seq_lens)) > 4096
+                and model_runner.sliding_window_size is None
+            ):
+                flashinfer_use_ragged = True
+            ret.init_flashinfer_handlers(
+                model_runner, prefix_lens, flashinfer_use_ragged
+            )
 
         return ret
 
+    def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
+        """Init auxiliary variables for triton attention backend."""
+        self.triton_max_seq_len = int(torch.max(self.seq_lens))
+        self.triton_prefix_lens = prefix_lens
+        self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
+        self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
+
+        if self.forward_mode == ForwardMode.DECODE:
+            self.triton_max_extend_len = None
+        else:
+            extend_seq_lens = self.seq_lens - prefix_lens
+            self.triton_max_extend_len = int(torch.max(extend_seq_lens))
+
+    def init_flashinfer_handlers(
+        self,
+        model_runner,
+        prefix_lens,
+        flashinfer_use_ragged,
+    ):
+        update_flashinfer_indices(
+            self.forward_mode,
+            model_runner,
+            self.req_pool_indices,
+            self.seq_lens,
+            prefix_lens,
+            flashinfer_use_ragged=flashinfer_use_ragged,
+        )
 
-def init_flashinfer_args(
+        (
+            self.flashinfer_prefill_wrapper_ragged,
+            self.flashinfer_prefill_wrapper_paged,
+            self.flashinfer_decode_wrapper,
+            self.flashinfer_use_ragged,
+        ) = (
+            model_runner.flashinfer_prefill_wrapper_ragged,
+            model_runner.flashinfer_prefill_wrapper_paged,
+            model_runner.flashinfer_decode_wrapper,
+            flashinfer_use_ragged,
+        )
+
+
+def update_flashinfer_indices(
     forward_mode,
     model_runner,
     req_pool_indices,
     seq_lens,
     prefix_lens,
-    flashinfer_decode_wrapper,
+    flashinfer_decode_wrapper=None,
     flashinfer_use_ragged=False,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
@@ -178,79 +258,136 @@ def init_flashinfer_args(
     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
     head_dim = model_runner.model_config.head_dim
     batch_size = len(req_pool_indices)
-    total_num_tokens = int(torch.sum(seq_lens))
-
-    if flashinfer_use_ragged:
-        paged_kernel_lens = prefix_lens
-    else:
-        paged_kernel_lens = seq_lens
-
-    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-    kv_indices = torch.cat(
-        [
-            model_runner.req_to_token_pool.req_to_token[
-                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-            ]
-            for i in range(batch_size)
-        ],
-        dim=0,
-    ).contiguous()
-    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
-
-    if forward_mode == ForwardMode.DECODE:
-        flashinfer_decode_wrapper.end_forward()
-        flashinfer_decode_wrapper.begin_forward(
-            kv_indptr,
-            kv_indices,
-            kv_last_page_len,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-            1,
-        )
-    else:
-        # extend part
-        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
 
+    if model_runner.sliding_window_size is None:
         if flashinfer_use_ragged:
-            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
-            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
-                qo_indptr,
+            paged_kernel_lens = prefix_lens
+        else:
+            paged_kernel_lens = seq_lens
+
+        kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+        kv_indices = torch.cat(
+            [
+                model_runner.req_to_token_pool.req_to_token[
+                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
+                ]
+                for i in range(batch_size)
+            ],
+            dim=0,
+        ).contiguous()
+        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+
+        if forward_mode == ForwardMode.DECODE:
+            # CUDA graph uses different flashinfer_decode_wrapper
+            if flashinfer_decode_wrapper is None:
+                flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
+
+            flashinfer_decode_wrapper.end_forward()
+            flashinfer_decode_wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                kv_last_page_len,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1,
+            )
+        else:
+            # extend part
+            qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+            qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+            if flashinfer_use_ragged:
+                model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+                model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+                    qo_indptr,
+                    qo_indptr,
+                    num_qo_heads,
+                    num_kv_heads,
+                    head_dim,
+                )
+
+            # cached part
+            model_runner.flashinfer_prefill_wrapper_paged.end_forward()
+            model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
                 qo_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_last_page_len,
                 num_qo_heads,
                 num_kv_heads,
                 head_dim,
+                1,
             )
-
-        # cached part
-        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
-        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
-            qo_indptr,
-            kv_indptr,
-            kv_indices,
-            kv_last_page_len,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-            1,
-        )
-
-
-def init_triton_args(forward_mode, seq_lens, prefix_lens):
-    """Init auxiliary variables for triton attention backend."""
-    batch_size = len(seq_lens)
-    max_seq_len = int(torch.max(seq_lens))
-    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-
-    if forward_mode == ForwardMode.DECODE:
-        max_extend_len = None
     else:
-        extend_seq_lens = seq_lens - prefix_lens
-        max_extend_len = int(torch.max(extend_seq_lens))
-
-    return max_seq_len, max_extend_len, start_loc, prefix_lens
+        # window attention use paged only
+        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+        for wrapper_id in range(2):
+            if wrapper_id == 0:
+                if forward_mode == ForwardMode.DECODE:
+                    paged_kernel_lens = torch.minimum(
+                        seq_lens, torch.tensor(model_runner.sliding_window_size + 1)
+                    )
+                else:
+                    paged_kernel_lens = torch.minimum(
+                        seq_lens,
+                        torch.tensor(model_runner.sliding_window_size)
+                        + seq_lens
+                        - prefix_lens,
+                    )
+            else:
+                paged_kernel_lens = seq_lens
+
+            kv_start_idx = seq_lens - paged_kernel_lens
+
+            kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+            kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+            req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+            paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+            kv_indices = torch.cat(
+                [
+                    model_runner.req_to_token_pool.req_to_token[
+                        req_pool_indices_cpu[i],
+                        kv_start_idx[i] : kv_start_idx[i] + paged_kernel_lens_cpu[i],
+                    ]
+                    for i in range(batch_size)
+                ],
+                dim=0,
+            ).contiguous()
+
+            if forward_mode == ForwardMode.DECODE:
+                # CUDA graph uses different flashinfer_decode_wrapper
+                if flashinfer_decode_wrapper is None:
+                    flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
+
+                flashinfer_decode_wrapper[wrapper_id].end_forward()
+                flashinfer_decode_wrapper[wrapper_id].begin_forward(
+                    kv_indptr,
+                    kv_indices,
+                    kv_last_page_len,
+                    num_qo_heads,
+                    num_kv_heads,
+                    head_dim,
+                    1,
+                )
+            else:
+                # extend part
+                qo_indptr = torch.zeros(
+                    (batch_size + 1,), dtype=torch.int32, device="cuda"
+                )
+                qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+                model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].end_forward()
+                model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].begin_forward(
+                    qo_indptr,
+                    kv_indptr,
+                    kv_indices,
+                    kv_last_page_len,
+                    num_qo_heads,
+                    num_kv_heads,
+                    head_dim,
+                    1,
+                )
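
In the sliding-window branch above, the first wrapper clamps each request's paged KV range to the attention window before kv_indptr and kv_indices are built, while the second wrapper keeps the full sequence. A self-contained sketch of just that clamping, with made-up values (sliding_window_size = 4 is an assumption for illustration):

import torch

sliding_window_size = 4                # assumed example value
seq_lens = torch.tensor([2, 5, 9])     # current total lengths per request
prefix_lens = torch.tensor([0, 3, 6])  # cached prefix lengths (extend mode)

# Decode: keep at most window + 1 KV entries per request.
decode_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size + 1))  # [2, 5, 5]

# Extend: keep the window plus the tokens being extended in this step.
extend_lens = torch.minimum(
    seq_lens, torch.tensor(sliding_window_size) + seq_lens - prefix_lens
)  # [2, 5, 7]

kv_start_idx = seq_lens - extend_lens  # first KV slot read per request: [0, 0, 2]
print(decode_lens, extend_lens, kv_start_idx)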