sglang 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/api.py +13 -1
  2. sglang/bench_latency.py +10 -5
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/global_config.py +1 -1
  6. sglang/lang/backend/runtime_endpoint.py +60 -49
  7. sglang/lang/chat_template.py +10 -5
  8. sglang/lang/compiler.py +4 -0
  9. sglang/lang/interpreter.py +5 -2
  10. sglang/lang/ir.py +22 -4
  11. sglang/launch_server.py +8 -1
  12. sglang/srt/constrained/jump_forward.py +13 -2
  13. sglang/srt/conversation.py +50 -1
  14. sglang/srt/hf_transformers_utils.py +22 -23
  15. sglang/srt/layers/activation.py +24 -2
  16. sglang/srt/layers/decode_attention.py +338 -50
  17. sglang/srt/layers/extend_attention.py +3 -1
  18. sglang/srt/layers/fused_moe/__init__.py +1 -0
  19. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  20. sglang/srt/layers/fused_moe/layer.py +587 -0
  21. sglang/srt/layers/layernorm.py +3 -0
  22. sglang/srt/layers/logits_processor.py +64 -27
  23. sglang/srt/layers/radix_attention.py +41 -18
  24. sglang/srt/layers/sampler.py +154 -0
  25. sglang/srt/managers/controller_multi.py +2 -8
  26. sglang/srt/managers/controller_single.py +7 -10
  27. sglang/srt/managers/detokenizer_manager.py +20 -9
  28. sglang/srt/managers/io_struct.py +44 -11
  29. sglang/srt/managers/policy_scheduler.py +5 -2
  30. sglang/srt/managers/schedule_batch.py +59 -179
  31. sglang/srt/managers/tokenizer_manager.py +193 -84
  32. sglang/srt/managers/tp_worker.py +131 -50
  33. sglang/srt/mem_cache/memory_pool.py +82 -8
  34. sglang/srt/mm_utils.py +79 -7
  35. sglang/srt/model_executor/cuda_graph_runner.py +97 -28
  36. sglang/srt/model_executor/forward_batch_info.py +188 -82
  37. sglang/srt/model_executor/model_runner.py +269 -87
  38. sglang/srt/models/chatglm.py +6 -14
  39. sglang/srt/models/commandr.py +6 -2
  40. sglang/srt/models/dbrx.py +5 -1
  41. sglang/srt/models/deepseek.py +7 -3
  42. sglang/srt/models/deepseek_v2.py +12 -7
  43. sglang/srt/models/gemma.py +6 -2
  44. sglang/srt/models/gemma2.py +22 -8
  45. sglang/srt/models/gpt_bigcode.py +5 -1
  46. sglang/srt/models/grok.py +66 -398
  47. sglang/srt/models/internlm2.py +5 -1
  48. sglang/srt/models/llama2.py +7 -3
  49. sglang/srt/models/llama_classification.py +2 -2
  50. sglang/srt/models/llama_embedding.py +4 -0
  51. sglang/srt/models/llava.py +176 -59
  52. sglang/srt/models/minicpm.py +7 -3
  53. sglang/srt/models/mixtral.py +61 -255
  54. sglang/srt/models/mixtral_quant.py +6 -5
  55. sglang/srt/models/qwen.py +7 -4
  56. sglang/srt/models/qwen2.py +15 -5
  57. sglang/srt/models/qwen2_moe.py +7 -16
  58. sglang/srt/models/stablelm.py +6 -2
  59. sglang/srt/openai_api/adapter.py +149 -58
  60. sglang/srt/sampling/sampling_batch_info.py +209 -0
  61. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -4
  62. sglang/srt/server.py +107 -71
  63. sglang/srt/server_args.py +49 -15
  64. sglang/srt/utils.py +27 -18
  65. sglang/test/runners.py +38 -38
  66. sglang/test/simple_eval_common.py +9 -10
  67. sglang/test/simple_eval_gpqa.py +2 -1
  68. sglang/test/simple_eval_humaneval.py +2 -2
  69. sglang/test/simple_eval_math.py +2 -1
  70. sglang/test/simple_eval_mmlu.py +2 -1
  71. sglang/test/test_activation.py +55 -0
  72. sglang/test/test_programs.py +32 -5
  73. sglang/test/test_utils.py +37 -50
  74. sglang/version.py +1 -1
  75. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/METADATA +102 -27
  76. sglang-0.2.14.dist-info/RECORD +114 -0
  77. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
  78. sglang/launch_server_llavavid.py +0 -29
  79. sglang/srt/model_loader/model_loader.py +0 -292
  80. sglang/srt/model_loader/utils.py +0 -275
  81. sglang-0.2.12.dist-info/RECORD +0 -112
  82. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
  83. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py

@@ -25,16 +25,18 @@ from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.layers.logits_processor import (
-    LogitProcessorOutput,
     LogitsMetadata,
     LogitsProcessor,
+    LogitsProcessorOutput,
 )
+from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
     InputMetadata,
     update_flashinfer_indices,
 )
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import monkey_patch_vllm_all_gather
 
 
@@ -84,13 +86,20 @@ def set_torch_compile_config():
 
 
 class CudaGraphRunner:
-    def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
+    def __init__(
+        self,
+        model_runner,
+        max_batch_size_to_capture: int,
+        use_torch_compile: bool,
+        disable_padding: bool,
+    ):
         self.model_runner = model_runner
         self.graphs = {}
         self.input_buffers = {}
         self.output_buffers = {}
         self.flashinfer_handlers = {}
         self.graph_memory_pool = None
+        self.disable_padding = disable_padding
 
         # Common inputs
         self.max_bs = max_batch_size_to_capture
@@ -98,8 +107,8 @@ class CudaGraphRunner:
         self.req_pool_indices = torch.zeros(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
-        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
-        self.position_ids_offsets = torch.zeros(
+        self.seq_lens = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.position_ids_offsets = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
         self.out_cache_loc = torch.zeros(
@@ -107,9 +116,6 @@ class CudaGraphRunner:
         )
 
         # FlashInfer inputs
-        self.flashinfer_workspace_buffer = (
-            self.model_runner.flashinfer_workspace_buffers[0]
-        )
         self.flashinfer_kv_indptr = torch.zeros(
             (self.max_bs + 1,), dtype=torch.int32, device="cuda"
         )
@@ -121,6 +127,27 @@ class CudaGraphRunner:
         self.flashinfer_kv_last_page_len = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
+        if model_runner.sliding_window_size is None:
+            self.flashinfer_workspace_buffer = (
+                self.model_runner.flashinfer_workspace_buffer
+            )
+        else:
+            self.flashinfer_workspace_buffer = (
+                self.model_runner.flashinfer_workspace_buffer
+            )
+
+            self.flashinfer_kv_indptr = [
+                self.flashinfer_kv_indptr,
+                self.flashinfer_kv_indptr.clone(),
+            ]
+            self.flashinfer_kv_indices = [
+                self.flashinfer_kv_indices,
+                self.flashinfer_kv_indices.clone(),
+            ]
+
+        # Sampling inputs
+        vocab_size = model_runner.model_config.vocab_size
+        self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
 
         self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
 
@@ -128,7 +155,10 @@ class CudaGraphRunner:
             set_torch_compile_config()
 
     def can_run(self, batch_size):
-        return batch_size < self.max_bs
+        if self.disable_padding:
+            return batch_size in self.graphs
+        else:
+            return batch_size <= self.max_bs
 
     def capture(self, batch_size_list):
         self.batch_size_list = batch_size_list
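
Note (editor sketch, not part of the diff): the new can_run check pairs with the replay-time padding further below. When padding is allowed, any batch up to the largest captured size is served by rounding up to the next captured graph; when padding is disabled, only exactly-captured sizes qualify. A minimal sketch of that dispatch arithmetic, with a hypothetical capture list:

    import bisect

    captured_sizes = [1, 2, 4, 8, 16]                 # hypothetical capture list
    graphs = {bs: f"graph_{bs}" for bs in captured_sizes}

    def can_run(batch_size, disable_padding, max_bs=16):
        # Mirrors the updated can_run: exact hit when padding is disabled,
        # otherwise anything up to the largest captured batch size.
        if disable_padding:
            return batch_size in graphs
        return batch_size <= max_bs

    def pick_padded_size(raw_bs):
        # Replay pads the batch up to the nearest captured size (bisect_left).
        return captured_sizes[bisect.bisect_left(captured_sizes, raw_bs)]

    assert can_run(3, disable_padding=False)
    assert not can_run(3, disable_padding=True)
    assert pick_padded_size(3) == 4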
@@ -171,15 +201,32 @@ class CudaGraphRunner:
             use_tensor_cores = True
         else:
             use_tensor_cores = False
-        flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-            self.flashinfer_workspace_buffer,
-            "NHD",
-            use_cuda_graph=True,
-            use_tensor_cores=use_tensor_cores,
-            paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
-            paged_kv_indices_buffer=self.flashinfer_kv_indices,
-            paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
-        )
+        if self.model_runner.sliding_window_size is None:
+            flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                self.flashinfer_workspace_buffer,
+                "NHD",
+                use_cuda_graph=True,
+                use_tensor_cores=use_tensor_cores,
+                paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
+                paged_kv_indices_buffer=self.flashinfer_kv_indices,
+                paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
+            )
+        else:
+            flashinfer_decode_wrapper = []
+            for i in range(2):
+                flashinfer_decode_wrapper.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.flashinfer_workspace_buffer,
+                        "NHD",
+                        use_cuda_graph=True,
+                        use_tensor_cores=use_tensor_cores,
+                        paged_kv_indptr_buffer=self.flashinfer_kv_indptr[i][: bs + 1],
+                        paged_kv_indices_buffer=self.flashinfer_kv_indices[i],
+                        paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[
+                            :bs
+                        ],
+                    )
+                )
         update_flashinfer_indices(
             ForwardMode.DECODE,
             self.model_runner,
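
Note (editor sketch, not part of the diff): for sliding-window models the runner builds two decode wrappers and, earlier in __init__, pairs each index buffer with an independent clone so the window and full-attention wrappers never alias each other's storage. A small sketch of that pairing idea using plain tensors (hypothetical sizes, not the FlashInfer API):

    import torch

    max_bs = 8                                         # hypothetical capture limit
    kv_indptr = torch.zeros((max_bs + 1,), dtype=torch.int32)

    # One buffer per wrapper: index 0 for window attention, index 1 for full attention.
    kv_indptr_pair = [kv_indptr, kv_indptr.clone()]

    kv_indptr_pair[0][1] = 5                           # write through wrapper 0's buffer
    assert kv_indptr_pair[1][1] == 0                   # wrapper 1's clone is unaffected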
@@ -193,6 +240,7 @@ class CudaGraphRunner:
         def run_once():
             input_metadata = InputMetadata(
                 forward_mode=ForwardMode.DECODE,
+                sampling_info=self.sampling_info[:bs],
                 batch_size=bs,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
@@ -201,19 +249,30 @@ class CudaGraphRunner:
                 out_cache_loc=out_cache_loc,
                 return_logprob=False,
                 top_logprobs_nums=0,
-                positions=(seq_lens - 1).to(torch.int64),
+                positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
                 flashinfer_decode_wrapper=flashinfer_decode_wrapper,
             )
 
             return forward(input_ids, input_metadata.positions, input_metadata)
 
         for _ in range(2):
+            torch.cuda.synchronize()
+            self.model_runner.tp_group.barrier()
+
             run_once()
 
+            torch.cuda.synchronize()
+            self.model_runner.tp_group.barrier()
+
         torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
         with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
             out = run_once()
+
         torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
         self.graph_memory_pool = graph.pool()
         return graph, None, out, flashinfer_decode_wrapper
 
@@ -225,8 +284,8 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.batch_size_list, raw_bs)
         bs = self.batch_size_list[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(1)
-            self.position_ids_offsets.zero_()
+            self.seq_lens.zero_()
+            self.position_ids_offsets.fill_(1)
             self.out_cache_loc.zero_()
 
         # Common inputs
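
Note (editor sketch, not part of the diff): the capture and replay paths above follow the standard PyTorch CUDA-graph recipe: warm-up runs, a synchronize, capture into a (shareable) memory pool, then replays that only mutate the static input buffers. A self-contained sketch of that pattern with a stand-in model (requires a CUDA device; not sglang code):

    import torch

    assert torch.cuda.is_available()
    model = torch.nn.Linear(16, 16).cuda()
    static_x = torch.zeros(8, 16, device="cuda")       # static input buffer

    def run_once():
        return model(static_x)

    # Warm up on a side stream (the recipe from the PyTorch docs) so lazy
    # allocations happen before capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(2):
            run_once()
    torch.cuda.current_stream().wait_stream(s)
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = run_once()
    graph_memory_pool = graph.pool()                   # later captures can share this pool
    torch.cuda.synchronize()

    # Replay: refill the static buffer, then re-run the recorded kernels.
    static_x.copy_(torch.randn(8, 16, device="cuda"))
    graph.replay()
    torch.cuda.synchronize()
    print(static_out[0, :4])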
@@ -246,25 +305,35 @@ class CudaGraphRunner:
             self.flashinfer_handlers[bs],
         )
 
+        # Sampling inputs
+        self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
+
         # Replay
+        torch.cuda.synchronize()
         self.graphs[bs].replay()
-        output = self.output_buffers[bs]
+        torch.cuda.synchronize()
+        sample_output, logits_output = self.output_buffers[bs]
 
         # Unpad
         if bs != raw_bs:
-            output = LogitProcessorOutput(
-                next_token_logits=output.next_token_logits[:raw_bs],
+            logits_output = LogitsProcessorOutput(
+                next_token_logits=logits_output.next_token_logits[:raw_bs],
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
                 input_token_logprobs=None,
                 input_top_logprobs=None,
                 output_top_logprobs=None,
             )
+            sample_output = SampleOutput(
+                sample_output.success[:raw_bs],
+                sample_output.probs[:raw_bs],
+                sample_output.batch_next_token_ids[:raw_bs],
+            )
 
         # Extract logprobs
         if batch.return_logprob:
-            output.next_token_logprobs = torch.nn.functional.log_softmax(
-                output.next_token_logits, dim=-1
+            logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
+                logits_output.next_token_logits, dim=-1
             )
             return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
             if return_top_logprob:
@@ -272,8 +341,8 @@ class CudaGraphRunner:
                     forward_mode=ForwardMode.DECODE,
                     top_logprobs_nums=batch.top_logprobs_nums,
                 )
-                output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
-                    output.next_token_logprobs, logits_metadata
+                logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
+                    logits_output.next_token_logprobs, logits_metadata
                 )[1]
 
-        return output
+        return sample_output, logits_output
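
Note (editor sketch, not part of the diff): because replay always runs at a captured (padded) batch size, both outputs are sliced back to the real batch size before returning, as in the hunk above. A tiny sketch of that unpadding step, with a hypothetical NamedTuple standing in for SampleOutput:

    from typing import NamedTuple
    import torch

    class SampleOut(NamedTuple):                       # stand-in, not the real SampleOutput
        success: torch.Tensor
        probs: torch.Tensor
        batch_next_token_ids: torch.Tensor

    bs, raw_bs, vocab = 8, 3, 11                       # padded vs. real batch size
    padded = SampleOut(
        torch.ones(bs, dtype=torch.bool),
        torch.softmax(torch.randn(bs, vocab), dim=-1),
        torch.randint(0, vocab, (bs,)),
    )

    # Slice every field back to the real batch size, mirroring the replay path.
    unpadded = SampleOut(*(t[:raw_bs] for t in padded))
    print(unpadded.batch_next_token_ids.shape)         # torch.Size([3])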
sglang/srt/model_executor/forward_batch_info.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -26,6 +28,7 @@ from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
 
 class ForwardMode(IntEnum):
@@ -42,6 +45,7 @@ class InputMetadata:
     """Store all inforamtion of a forward pass."""
 
     forward_mode: ForwardMode
+    sampling_info: SamplingBatchInfo
     batch_size: int
     req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor
@@ -61,9 +65,11 @@ class InputMetadata:
     extend_start_loc: torch.Tensor = None
     extend_no_prefix: bool = None
 
-    # Output options
+    # For logprob
     return_logprob: bool = False
    top_logprobs_nums: List[int] = None
+    extend_seq_lens_cpu: List[int] = None
+    logprob_start_lens_cpu: List[int] = None
 
     # For multimodal
     pixel_values: List[torch.Tensor] = None
@@ -86,14 +92,19 @@ class InputMetadata:
         reqs = batch.reqs
         self.pixel_values = [r.pixel_values for r in reqs]
         self.image_sizes = [r.image_size for r in reqs]
-        self.image_offsets = [
-            (
-                (r.image_offset - len(r.prefix_indices))
-                if r.image_offset is not None
-                else 0
-            )
-            for r in reqs
-        ]
+        self.image_offsets = []
+        for r in reqs:
+            if isinstance(r.image_offset, list):
+                self.image_offsets.append(
+                    [
+                        (image_offset - len(r.prefix_indices))
+                        for image_offset in r.image_offset
+                    ]
+                )
+            elif isinstance(r.image_offset, int):
+                self.image_offsets.append(r.image_offset - len(r.prefix_indices))
+            elif r.image_offset is None:
+                self.image_offsets.append(0)
 
     def compute_positions(self, batch: ScheduleBatch):
         position_ids_offsets = batch.position_ids_offsets
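
Note (editor sketch, not part of the diff): the rewritten init_multimuldal_info accepts a per-request image_offset that may be a single int, a list of ints (multiple images), or None, and rebases each offset by the length of the cached prefix. The same normalization in isolation, with made-up values:

    def normalize_image_offsets(image_offset, prefix_len):
        # Rebase by the cached prefix length, mirroring the int/list/None cases above.
        if isinstance(image_offset, list):
            return [off - prefix_len for off in image_offset]
        if isinstance(image_offset, int):
            return image_offset - prefix_len
        return 0                                       # request without an image

    print(normalize_image_offsets(37, 10))             # single image -> 27
    print(normalize_image_offsets([5, 40], 5))         # multiple images -> [0, 35]
    print(normalize_image_offsets(None, 8))            # no image -> 0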
@@ -109,8 +120,8 @@ class InputMetadata:
             self.positions = torch.tensor(
                 np.concatenate(
                     [
-                        np.arange(len(req.prefix_indices), len(req.fill_ids))
-                        for req in batch.reqs
+                        np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
+                        for i, req in enumerate(batch.reqs)
                     ],
                     axis=0,
                 ),
@@ -123,7 +134,7 @@ class InputMetadata:
                 np.concatenate(
                     [
                         np.arange(
-                            len(req.prefix_indices) + position_ids_offsets_cpu[i],
+                            batch.prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
                             len(req.fill_ids) + position_ids_offsets_cpu[i],
                         )
                         for i, req in enumerate(batch.reqs)
@@ -139,14 +150,29 @@ class InputMetadata:
     def compute_extend_infos(self, batch: ScheduleBatch):
         if self.forward_mode == ForwardMode.DECODE:
             self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
+            self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None
         else:
             extend_lens_cpu = [
-                len(r.fill_ids) - len(r.prefix_indices) for r in batch.reqs
+                len(r.fill_ids) - batch.prefix_lens_cpu[i]
+                for i, r in enumerate(batch.reqs)
             ]
             self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
             self.extend_start_loc = torch.zeros_like(self.seq_lens)
             self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-            self.extend_no_prefix = all(len(r.prefix_indices) == 0 for r in batch.reqs)
+            self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
+
+            self.extend_seq_lens_cpu = extend_lens_cpu
+            self.logprob_start_lens_cpu = [
+                (
+                    min(
+                        req.logprob_start_len - batch.prefix_lens_cpu[i],
+                        extend_lens_cpu[i] - 1,
+                    )
+                    if req.logprob_start_len >= batch.prefix_lens_cpu[i]
+                    else extend_lens_cpu[i] - 1  # Fake extend, actually decode
+                )
+                for i, req in enumerate(batch.reqs)
+            ]
 
     @classmethod
     def from_schedule_batch(
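
Note (editor sketch, not part of the diff): the new logprob_start_lens_cpu clamps each request's logprob start into its extend region; a start that falls inside the cached prefix behaves like a decode step and collapses to the last extend position. A tiny numeric sketch of that rule with made-up lengths:

    def logprob_start_len_in_extend(logprob_start_len, prefix_len, extend_len):
        # Same clamping as above: relative to the extend region, capped at the
        # last extend token; prefix-covered starts act like decode.
        if logprob_start_len >= prefix_len:
            return min(logprob_start_len - prefix_len, extend_len - 1)
        return extend_len - 1                          # fake extend, actually decode

    # 4 cached prefix tokens, 6 newly extended tokens
    print(logprob_start_len_in_extend(5, prefix_len=4, extend_len=6))   # -> 1
    print(logprob_start_len_in_extend(2, prefix_len=4, extend_len=6))   # -> 5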
@@ -157,6 +183,7 @@ class InputMetadata:
     ):
         ret = cls(
             forward_mode=forward_mode,
+            sampling_info=batch.sampling_info,
             batch_size=batch.batch_size(),
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
@@ -167,6 +194,8 @@ class InputMetadata:
             top_logprobs_nums=batch.top_logprobs_nums,
         )
 
+        ret.sampling_info.prepare_penalties()
+
         ret.compute_positions(batch)
 
         ret.compute_extend_infos(batch)
@@ -180,44 +209,47 @@ class InputMetadata:
         if forward_mode != ForwardMode.DECODE:
             ret.init_multimuldal_info(batch)
 
-        prefix_lens = None
-        if forward_mode != ForwardMode.DECODE:
-            prefix_lens = torch.tensor(
-                [len(r.prefix_indices) for r in batch.reqs], device="cuda"
-            )
-
         if model_runner.server_args.disable_flashinfer:
-            ret.init_triton_args(batch, prefix_lens)
+            ret.init_triton_args(batch)
 
         flashinfer_use_ragged = False
         if not model_runner.server_args.disable_flashinfer:
             if (
                 forward_mode != ForwardMode.DECODE
                 and int(torch.sum(ret.seq_lens)) > 4096
+                and model_runner.sliding_window_size is None
             ):
                 flashinfer_use_ragged = True
             ret.init_flashinfer_handlers(
-                model_runner, prefix_lens, flashinfer_use_ragged
+                model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged
             )
 
         return ret
 
-    def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
+    def init_triton_args(self, batch: ScheduleBatch):
         """Init auxiliary variables for triton attention backend."""
         self.triton_max_seq_len = int(torch.max(self.seq_lens))
-        self.triton_prefix_lens = prefix_lens
         self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
         self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
 
         if self.forward_mode == ForwardMode.DECODE:
             self.triton_max_extend_len = None
         else:
-            extend_seq_lens = self.seq_lens - prefix_lens
+            self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
+            extend_seq_lens = self.seq_lens - self.triton_prefix_lens
             self.triton_max_extend_len = int(torch.max(extend_seq_lens))
 
     def init_flashinfer_handlers(
-        self, model_runner, prefix_lens, flashinfer_use_ragged
+        self,
+        model_runner,
+        prefix_lens_cpu,
+        flashinfer_use_ragged,
     ):
+        if self.forward_mode != ForwardMode.DECODE:
+            prefix_lens = torch.tensor(prefix_lens_cpu, device="cuda")
+        else:
+            prefix_lens = None
+
         update_flashinfer_indices(
             self.forward_mode,
             model_runner,
@@ -255,65 +287,139 @@ def update_flashinfer_indices(
     head_dim = model_runner.model_config.head_dim
     batch_size = len(req_pool_indices)
 
-    if flashinfer_use_ragged:
-        paged_kernel_lens = prefix_lens
-    else:
-        paged_kernel_lens = seq_lens
-
-    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-    kv_indices = torch.cat(
-        [
-            model_runner.req_to_token_pool.req_to_token[
-                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-            ]
-            for i in range(batch_size)
-        ],
-        dim=0,
-    ).contiguous()
-    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
-
-    if forward_mode == ForwardMode.DECODE:
-        # CUDA graph uses different flashinfer_decode_wrapper
-        if flashinfer_decode_wrapper is None:
-            flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
-
-        flashinfer_decode_wrapper.end_forward()
-        flashinfer_decode_wrapper.begin_forward(
-            kv_indptr,
-            kv_indices,
-            kv_last_page_len,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-            1,
-        )
-    else:
-        # extend part
-        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
-
-        if flashinfer_use_ragged:
-            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
-            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
-                qo_indptr,
-                qo_indptr,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-            )
-
-        # cached part
-        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
-        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
-            qo_indptr,
-            kv_indptr,
-            kv_indices,
-            kv_last_page_len,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-            1,
-        )
+    if model_runner.sliding_window_size is None:
+        if flashinfer_use_ragged:
+            paged_kernel_lens = prefix_lens
+        else:
+            paged_kernel_lens = seq_lens
+
+        kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+        kv_indices = torch.cat(
+            [
+                model_runner.req_to_token_pool.req_to_token[
+                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
+                ]
+                for i in range(batch_size)
+            ],
+            dim=0,
+        ).contiguous()
+        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+
+        if forward_mode == ForwardMode.DECODE:
+            # CUDA graph uses different flashinfer_decode_wrapper
+            if flashinfer_decode_wrapper is None:
+                flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
+
+            flashinfer_decode_wrapper.end_forward()
+            flashinfer_decode_wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                kv_last_page_len,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1,
+                data_type=model_runner.kv_cache_dtype,
+                q_data_type=model_runner.dtype,
+            )
+        else:
+            # extend part
+            qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+            qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+            if flashinfer_use_ragged:
+                model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+                model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+                    qo_indptr,
+                    qo_indptr,
+                    num_qo_heads,
+                    num_kv_heads,
+                    head_dim,
+                )
+
+            # cached part
+            model_runner.flashinfer_prefill_wrapper_paged.end_forward()
+            model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
+                qo_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_last_page_len,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1,
+            )
+    else:
+        # window attention use paged only
+        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+        for wrapper_id in range(2):
+            if wrapper_id == 0:
+                if forward_mode == ForwardMode.DECODE:
+                    paged_kernel_lens = torch.minimum(
+                        seq_lens, torch.tensor(model_runner.sliding_window_size + 1)
+                    )
+                else:
+                    paged_kernel_lens = torch.minimum(
+                        seq_lens,
+                        torch.tensor(model_runner.sliding_window_size)
+                        + seq_lens
+                        - prefix_lens,
+                    )
+            else:
+                paged_kernel_lens = seq_lens
+
+            kv_start_idx = seq_lens - paged_kernel_lens
+
+            kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+            kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+            req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+            paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+            kv_indices = torch.cat(
+                [
+                    model_runner.req_to_token_pool.req_to_token[
+                        req_pool_indices_cpu[i],
+                        kv_start_idx[i] : kv_start_idx[i] + paged_kernel_lens_cpu[i],
+                    ]
+                    for i in range(batch_size)
+                ],
+                dim=0,
+            ).contiguous()
+
+            if forward_mode == ForwardMode.DECODE:
+                # CUDA graph uses different flashinfer_decode_wrapper
+                if flashinfer_decode_wrapper is None:
+                    flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
+
+                flashinfer_decode_wrapper[wrapper_id].end_forward()
+                flashinfer_decode_wrapper[wrapper_id].begin_forward(
+                    kv_indptr,
+                    kv_indices,
+                    kv_last_page_len,
+                    num_qo_heads,
+                    num_kv_heads,
+                    head_dim,
+                    1,
+                    data_type=model_runner.kv_cache_dtype,
+                    q_data_type=model_runner.dtype,
+                )
+            else:
+                # extend part
+                qo_indptr = torch.zeros(
+                    (batch_size + 1,), dtype=torch.int32, device="cuda"
+                )
+                qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+                model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].end_forward()
+                model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].begin_forward(
+                    qo_indptr,
+                    kv_indptr,
+                    kv_indices,
+                    kv_last_page_len,
+                    num_qo_heads,
+                    num_kv_heads,
+                    head_dim,
+                    1,
+                )
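
Note (editor sketch, not part of the diff): in the sliding-window branch, wrapper 0 only sees the most recent window of the KV cache, so its per-request kernel length is clamped to the window (plus one token at decode time) and its indices start at seq_len minus that length, while wrapper 1 keeps the full sequence. A small numeric sketch of the decode-time length math, assuming a window of 4:

    import torch

    sliding_window_size = 4                            # hypothetical window
    seq_lens = torch.tensor([2, 6, 10], dtype=torch.int32)

    # Wrapper 0 (window attention, decode): at most window + 1 tokens of KV.
    paged_kernel_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size + 1))
    kv_start_idx = seq_lens - paged_kernel_lens

    # Wrapper 1 (full attention): the whole sequence.
    print(paged_kernel_lens.tolist())                  # [2, 5, 5]
    print(kv_start_idx.tolist())                       # [0, 1, 5]
    print(seq_lens.tolist())                           # [2, 6, 10]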