sglang 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. sglang/bench_latency.py +31 -13
  2. sglang/bench_server_latency.py +21 -10
  3. sglang/bench_serving.py +101 -7
  4. sglang/global_config.py +0 -1
  5. sglang/srt/conversation.py +11 -2
  6. sglang/srt/layers/attention/__init__.py +27 -5
  7. sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
  8. sglang/srt/layers/attention/flashinfer_backend.py +352 -83
  9. sglang/srt/layers/attention/triton_backend.py +6 -4
  10. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
  11. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
  12. sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
  13. sglang/srt/layers/sampler.py +6 -2
  14. sglang/srt/managers/data_parallel_controller.py +177 -0
  15. sglang/srt/managers/detokenizer_manager.py +31 -10
  16. sglang/srt/managers/io_struct.py +11 -2
  17. sglang/srt/managers/schedule_batch.py +126 -43
  18. sglang/srt/managers/schedule_policy.py +2 -1
  19. sglang/srt/managers/scheduler.py +245 -142
  20. sglang/srt/managers/tokenizer_manager.py +14 -1
  21. sglang/srt/managers/tp_worker.py +111 -1
  22. sglang/srt/mem_cache/chunk_cache.py +8 -4
  23. sglang/srt/mem_cache/memory_pool.py +77 -4
  24. sglang/srt/mem_cache/radix_cache.py +15 -7
  25. sglang/srt/model_executor/cuda_graph_runner.py +4 -4
  26. sglang/srt/model_executor/forward_batch_info.py +16 -21
  27. sglang/srt/model_executor/model_runner.py +100 -36
  28. sglang/srt/models/baichuan.py +2 -3
  29. sglang/srt/models/chatglm.py +5 -6
  30. sglang/srt/models/commandr.py +1 -2
  31. sglang/srt/models/dbrx.py +1 -2
  32. sglang/srt/models/deepseek.py +4 -5
  33. sglang/srt/models/deepseek_v2.py +5 -6
  34. sglang/srt/models/exaone.py +1 -2
  35. sglang/srt/models/gemma.py +2 -2
  36. sglang/srt/models/gemma2.py +5 -5
  37. sglang/srt/models/gpt_bigcode.py +5 -5
  38. sglang/srt/models/grok.py +1 -2
  39. sglang/srt/models/internlm2.py +1 -2
  40. sglang/srt/models/llama.py +1 -2
  41. sglang/srt/models/llama_classification.py +1 -2
  42. sglang/srt/models/llama_reward.py +2 -3
  43. sglang/srt/models/llava.py +4 -8
  44. sglang/srt/models/llavavid.py +1 -2
  45. sglang/srt/models/minicpm.py +1 -2
  46. sglang/srt/models/minicpm3.py +5 -6
  47. sglang/srt/models/mixtral.py +1 -2
  48. sglang/srt/models/mixtral_quant.py +1 -2
  49. sglang/srt/models/olmo.py +352 -0
  50. sglang/srt/models/olmoe.py +1 -2
  51. sglang/srt/models/qwen.py +1 -2
  52. sglang/srt/models/qwen2.py +1 -2
  53. sglang/srt/models/qwen2_moe.py +4 -5
  54. sglang/srt/models/stablelm.py +1 -2
  55. sglang/srt/models/torch_native_llama.py +1 -2
  56. sglang/srt/models/xverse.py +1 -2
  57. sglang/srt/models/xverse_moe.py +4 -5
  58. sglang/srt/models/yivl.py +1 -2
  59. sglang/srt/openai_api/adapter.py +97 -52
  60. sglang/srt/openai_api/protocol.py +10 -2
  61. sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
  62. sglang/srt/sampling/sampling_batch_info.py +105 -59
  63. sglang/srt/sampling/sampling_params.py +2 -0
  64. sglang/srt/server.py +171 -37
  65. sglang/srt/server_args.py +127 -48
  66. sglang/srt/utils.py +37 -14
  67. sglang/test/few_shot_gsm8k.py +4 -1
  68. sglang/test/few_shot_gsm8k_engine.py +144 -0
  69. sglang/test/srt/sampling/penaltylib/utils.py +16 -12
  70. sglang/version.py +1 -1
  71. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/METADATA +82 -32
  72. sglang-0.3.4.dist-info/RECORD +143 -0
  73. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
  74. sglang/srt/layers/attention/flashinfer_utils.py +0 -237
  75. sglang-0.3.3.dist-info/RECORD +0 -139
  76. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
  77. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py

@@ -7,18 +7,17 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """
 
+from enum import Enum, auto
 from typing import TYPE_CHECKING
 
 import torch
 import torch.nn as nn
+import triton
+import triton.language as tl
 
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.layers.attention.flashinfer_utils import (
-    WrapperDispatch,
-    update_flashinfer_indices,
-)
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import is_flashinfer_available
 
 if TYPE_CHECKING:
@@ -34,13 +33,18 @@ if is_flashinfer_available():
     from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 
 
+class WrapperDispatch(Enum):
+    SLIDING_WINDOW = auto()
+    CROSS_ATTENTION = auto()
+
+
 class FlashInferAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""
 
     def __init__(self, model_runner: ModelRunner):
         super().__init__()
-        self.model_runner = model_runner
 
+        # Parse constants
         if not _grouped_size_compiled_for_decode_kernels(
             model_runner.model_config.num_attention_heads // model_runner.tp_size,
             model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
@@ -48,27 +52,43 @@ class FlashInferAttnBackend(AttentionBackend):
             self.decode_use_tensor_cores = True
         else:
             self.decode_use_tensor_cores = False
-
-        self.workspace_buffer = torch.empty(
-            global_config.flashinfer_workspace_size,
-            dtype=torch.uint8,
-            device="cuda",
-        )
+        self.max_context_len = model_runner.model_config.context_len
 
         assert not (
             model_runner.sliding_window_size is not None
             and model_runner.has_cross_attention
         ), "Sliding window and cross attention are not supported together"
 
-        self.num_wrappers = 1
-        self.dispatch_reason = None
         if model_runner.sliding_window_size is not None:
             self.num_wrappers = 2
             self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
         elif model_runner.has_cross_attention:
             self.num_wrappers = 2
             self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
+        else:
+            self.num_wrappers = 1
+            self.dispatch_reason = None
+
+        # Allocate buffers
+        self.workspace_buffer = torch.empty(
+            global_config.flashinfer_workspace_size,
+            dtype=torch.uint8,
+            device=model_runner.device,
+        )
+        max_bs = model_runner.req_to_token_pool.size
+        self.kv_indptr = [
+            torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
+            for _ in range(self.num_wrappers)
+        ]
+        self.kv_last_page_len = torch.ones(
+            (max_bs,), dtype=torch.int32, device=model_runner.device
+        )
+        self.qo_indptr = [
+            torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
+            for _ in range(self.num_wrappers)
+        ]
 
+        # Create wrappers
         # NOTE: we do not use ragged attention when there are multiple wrappers
         self.prefill_wrapper_ragged = (
             BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
@@ -92,26 +112,23 @@ class FlashInferAttnBackend(AttentionBackend):
                 )
             )
 
+        # Create indices updater
+        self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
+        self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
+            model_runner, self
+        )
+
+        # Other metadata
         self.forward_metadata = None
         self.cuda_graph_metadata = {}
 
-    def _get_wrapper_idx(self, layer: nn.Module):
-        if self.num_wrappers == 1:
-            return 0
-
-        if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
-            return layer.sliding_window_size == -1
-        if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
-            return layer.is_cross_attention
-
-        raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}")
-
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode():
-            prefix_lens = None
-            use_ragged = False
-            extend_no_prefix = False
-            total_num_tokens = None
+            self.indices_updater_decode.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+            )
+            self.forward_metadata = (self.decode_wrappers,)
         else:
             prefix_lens = forward_batch.extend_prefix_lens
 
@@ -123,48 +140,32 @@ class FlashInferAttnBackend(AttentionBackend):
             ):
                 use_ragged = True
 
-            total_num_tokens = torch.sum(forward_batch.seq_lens).item()
             extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()
 
-            update_flashinfer_indices(
-                forward_batch.forward_mode,
-                self.model_runner,
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                prefix_lens,
-                use_ragged=use_ragged,
-            )
+            self.indices_updater_prefill.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                prefix_lens,
+                use_ragged,
+            )
 
-            self.forward_metadata = (
-                use_ragged,
-                extend_no_prefix,
-                total_num_tokens,
-                self.decode_wrappers,
-            )
+            self.forward_metadata = (
+                use_ragged,
+                extend_no_prefix,
+            )
 
     def init_cuda_graph_state(self, max_bs: int):
-        self.cuda_graph_kv_indptr = torch.zeros(
-            (max_bs + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.cuda_graph_kv_indices = torch.zeros(
-            (max_bs * self.model_runner.model_config.context_len,),
+        cuda_graph_kv_indices = torch.zeros(
+            (max_bs * self.max_context_len,),
             dtype=torch.int32,
             device="cuda",
         )
-        self.cuda_graph_kv_last_page_len = torch.ones(
-            (max_bs,), dtype=torch.int32, device="cuda"
-        )
-
-        # NOTE: the buffers are always in the form of list
-        self.cuda_graph_kv_indptr = [self.cuda_graph_kv_indptr] + [
-            self.cuda_graph_kv_indptr.clone() for _ in range(self.num_wrappers - 1)
-        ]
-        self.cuda_graph_kv_indices = [self.cuda_graph_kv_indices] + [
-            self.cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
+        self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [
+            cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
         ]
 
     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
    ):
         decode_wrappers = []
         for i in range(self.num_wrappers):
@@ -174,35 +175,21 @@ class FlashInferAttnBackend(AttentionBackend):
                     "NHD",
                     use_cuda_graph=True,
                     use_tensor_cores=self.decode_use_tensor_cores,
-                    paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1],
+                    paged_kv_indptr_buffer=self.kv_indptr[i][: bs + 1],
                     paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                    paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
+                    paged_kv_last_page_len_buffer=self.kv_last_page_len[:bs],
                 )
             )
 
-        update_flashinfer_indices(
-            ForwardMode.DECODE,
-            self.model_runner,
-            req_pool_indices,
-            seq_lens,
-            None,
-            decode_wrappers,
-        )
-
+        self.indices_updater_decode.update(req_pool_indices, seq_lens, decode_wrappers)
         self.cuda_graph_metadata[bs] = decode_wrappers
-
-        self.forward_metadata = (False, False, None, decode_wrappers)
+        self.forward_metadata = (decode_wrappers,)
 
     def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
-        update_flashinfer_indices(
-            ForwardMode.DECODE,
-            self.model_runner,
-            req_pool_indices[:bs],
-            seq_lens[:bs],
-            None,
-            self.cuda_graph_metadata[bs],
+        self.indices_updater_decode.update(
+            req_pool_indices[:bs], seq_lens[:bs], self.cuda_graph_metadata[bs]
         )
 
     def get_cuda_graph_seq_len_fill_value(self):
@@ -213,7 +200,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self._get_wrapper_idx(layer)
         ]
 
-        use_ragged, extend_no_prefix, _, _ = self.forward_metadata
+        use_ragged, extend_no_prefix = self.forward_metadata
 
         if not use_ragged:
             if k is not None:
@@ -259,7 +246,7 @@ class FlashInferAttnBackend(AttentionBackend):
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
-        decode_wrapper = self.forward_metadata[-1][self._get_wrapper_idx(layer)]
+        decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
 
         if k is not None:
             assert v is not None
@@ -275,3 +262,285 @@ class FlashInferAttnBackend(AttentionBackend):
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+    def _get_wrapper_idx(self, layer: nn.Module):
+        if self.num_wrappers == 1:
+            return 0
+
+        if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+            return layer.sliding_window_size == -1
+        if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+            return layer.is_cross_attention
+
+        raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}")
+
+
+class FlashInferIndicesUpdaterDecode:
+    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+        # Constants
+        self.num_qo_heads = (
+            model_runner.model_config.num_attention_heads // model_runner.tp_size
+        )
+        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
+            model_runner.tp_size
+        )
+        self.head_dim = model_runner.model_config.head_dim
+        self.data_type = model_runner.kv_cache_dtype
+        self.q_data_type = model_runner.dtype
+        self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
+        self.sliding_window_size = model_runner.sliding_window_size
+
+        # Buffers and wrappers
+        self.kv_indptr = attn_backend.kv_indptr
+        self.kv_last_page_len = attn_backend.kv_last_page_len
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.decode_wrappers = attn_backend.decode_wrappers
+
+        # Dispatch
+        if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+            self.update = self.update_sliding_window
+        elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+            self.update = self.update_cross_attention
+        else:
+            assert attn_backend.num_wrappers == 1
+            self.update = self.update_single_wrapper
+
+    def update_single_wrapper(self, req_pool_indices, seq_lens, decode_wrappers=None):
+        decode_wrappers = decode_wrappers or self.decode_wrappers
+        self.call_begin_forward(
+            decode_wrappers[0], req_pool_indices, seq_lens, self.kv_indptr[0], None
+        )
+
+    def update_sliding_window(self, req_pool_indices, seq_lens, decode_wrappers=None):
+        decode_wrappers = decode_wrappers or self.decode_wrappers
+
+        for wrapper_id in range(2):
+            if wrapper_id == 0:
+                # Sliding window attention
+                paged_kernel_lens = torch.minimum(  # TODO: replace this with clamp
+                    seq_lens,
+                    torch.tensor(self.sliding_window_size + 1),
+                )
+            else:
+                # Full attention
+                paged_kernel_lens = seq_lens
+
+            kv_start_idx = seq_lens - paged_kernel_lens
+
+            self.call_begin_forward(
+                decode_wrappers[wrapper_id],
+                req_pool_indices,
+                paged_kernel_lens,
+                self.kv_indptr[wrapper_id],
+                kv_start_idx,
+            )
+
+    def update_cross_attention(self):
+        raise NotImplementedError()
+
+    def call_begin_forward(
+        self, wrapper, req_pool_indices, paged_kernel_lens, kv_indptr, kv_start_idx
+    ):
+        bs = len(req_pool_indices)
+        kv_indptr = kv_indptr[: bs + 1]
+        # TODO: optimize the blocking call on kv_indptr[-1]
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+
+        create_flashinfer_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            kv_start_idx,
+            kv_indices,
+            self.max_context_len,
+        )
+
+        wrapper.end_forward()
+        wrapper.begin_forward(
+            kv_indptr,
+            kv_indices,
+            self.kv_last_page_len[:bs],
+            self.num_qo_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            1,
+            data_type=self.data_type,
+            q_data_type=self.q_data_type,
+        )
+
+
+class FlashInferIndicesUpdaterPrefill:
+    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+        # Constants
+        self.num_qo_heads = (
+            model_runner.model_config.num_attention_heads // model_runner.tp_size
+        )
+        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
+            model_runner.tp_size
+        )
+        self.head_dim = model_runner.model_config.head_dim
+        self.data_type = model_runner.kv_cache_dtype
+        self.q_data_type = model_runner.dtype
+        self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
+        self.sliding_window_size = model_runner.sliding_window_size
+
+        # Buffers and wrappers
+        self.kv_indptr = attn_backend.kv_indptr
+        self.kv_last_page_len = attn_backend.kv_last_page_len
+        self.qo_indptr = attn_backend.qo_indptr
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.wrapper_ragged = attn_backend.prefill_wrapper_ragged
+        self.wrappers_paged = attn_backend.prefill_wrappers_paged
+
+        # Dispatch
+        if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+            self.update = self.update_sliding_window
+        elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+            self.update = self.update_cross_attention
+        else:
+            assert attn_backend.num_wrappers == 1
+            self.update = self.update_single_wrapper
+
+    def update_single_wrapper(
+        self, req_pool_indices, seq_lens, prefix_lens, use_ragged
+    ):
+        if use_ragged:
+            paged_kernel_lens = prefix_lens
+        else:
+            paged_kernel_lens = seq_lens
+
+        self.call_begin_forward(
+            self.wrapper_ragged,
+            self.wrappers_paged[0],
+            req_pool_indices,
+            paged_kernel_lens,
+            seq_lens,
+            prefix_lens,
+            None,
+            self.kv_indptr[0],
+            self.qo_indptr[0],
+            use_ragged,
+        )
+
+    def update_sliding_window(
+        self, req_pool_indices, seq_lens, prefix_lens, use_ragged
+    ):
+        for wrapper_id in range(2):
+            if wrapper_id == 0:
+                # window attention use paged only
+                paged_kernel_lens = torch.minimum(
+                    seq_lens,
+                    torch.tensor(self.sliding_window_size) + seq_lens - prefix_lens,
+                )
+            else:
+                # full attention
+                paged_kernel_lens = seq_lens
+            kv_start_idx = seq_lens - paged_kernel_lens
+
+            self.call_begin_forward(
+                self.wrapper_ragged,
+                self.wrappers_paged[wrapper_id],
+                req_pool_indices,
+                paged_kernel_lens,
+                seq_lens,
+                prefix_lens,
+                kv_start_idx,
+                self.kv_indptr[wrapper_id],
+                self.qo_indptr[wrapper_id],
+                use_ragged,
+            )
+
+    def update_cross_attention(self):
+        raise NotImplementedError()
+
+    def call_begin_forward(
+        self,
+        wrapper_ragged,
+        wrapper_paged,
+        req_pool_indices,
+        paged_kernel_lens,
+        seq_lens,
+        prefix_lens,
+        kv_start_idx,
+        kv_indptr,
+        qo_indptr,
+        use_ragged,
+    ):
+        bs = len(req_pool_indices)
+        kv_indptr = kv_indptr[: bs + 1]
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+        create_flashinfer_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            kv_start_idx,
+            kv_indices,
+            self.max_context_len,
+        )
+
+        qo_indptr = qo_indptr[: bs + 1]
+        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+        # extend part
+        if use_ragged:
+            wrapper_ragged.end_forward()
+            wrapper_ragged.begin_forward(
+                qo_indptr,
+                qo_indptr,
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+            )
+
+        # cached part
+        wrapper_paged.end_forward()
+        wrapper_paged.begin_forward(
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            self.kv_last_page_len[:bs],
+            self.num_qo_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            1,
+        )
+
+
+@triton.jit
+def create_flashinfer_kv_indices_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices_ptr,
+    page_kernel_lens_ptr,
+    kv_indptr,
+    kv_start_idx,
+    kv_indices_ptr,
+    max_context_len: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(axis=0)
+    req_pool_index = tl.load(req_pool_indices_ptr + pid)
+    kv_indices_offset = tl.load(kv_indptr + pid)
+
+    kv_start = 0
+    kv_end = 0
+    if kv_start_idx:
+        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
+        kv_end = kv_start
+    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
+
+    req_to_token_ptr += req_pool_index * max_context_len
+    kv_indices_ptr += kv_indices_offset
+
+    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
+    st_offset = tl.arange(0, BLOCK_SIZE)
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for _ in range(num_loop):
+        mask = ld_offset < kv_end
+        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
+        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
+        ld_offset += BLOCK_SIZE
+        st_offset += BLOCK_SIZE
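
To make the new index-building path above easier to follow: FlashInferIndicesUpdaterDecode.call_begin_forward builds kv_indptr as the running sum of each request's KV length, then launches create_flashinfer_kv_indices_triton with one program per request to gather that request's KV-cache slot ids from the req_to_token table into a flat kv_indices array (shifted by kv_start_idx for sliding-window layers). The snippet below is a minimal pure-PyTorch sketch of that computation for illustration only; the helper name build_kv_indices_reference and the toy tensors are not part of sglang, and the shipped code does this gather on the GPU with the Triton kernel shown above.

import torch

def build_kv_indices_reference(
    req_to_token, req_pool_indices, paged_kernel_lens, kv_start_idx=None
):
    """Loop-based sketch of what create_flashinfer_kv_indices_triton produces.

    req_to_token:      [max_batch, max_context_len] table of KV-cache slot ids
    req_pool_indices:  [bs] row of req_to_token owned by each running request
    paged_kernel_lens: [bs] number of KV entries each request contributes
    kv_start_idx:      [bs] optional per-request start offset (sliding window)
    """
    bs = len(req_pool_indices)
    kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)  # prefix sum of lengths

    chunks = []
    for i in range(bs):
        row = req_to_token[req_pool_indices[i]]
        start = int(kv_start_idx[i]) if kv_start_idx is not None else 0
        end = start + int(paged_kernel_lens[i])
        chunks.append(row[start:end])  # the slice each Triton program copies
    kv_indices = torch.cat(chunks).to(torch.int32)
    return kv_indptr, kv_indices

# Toy example: two requests holding 3 and 2 cached tokens.
req_to_token = torch.arange(20, dtype=torch.int32).reshape(4, 5)
kv_indptr, kv_indices = build_kv_indices_reference(
    req_to_token,
    req_pool_indices=torch.tensor([2, 0]),
    paged_kernel_lens=torch.tensor([3, 2]),
)
print(kv_indptr.tolist())   # [0, 3, 5]
print(kv_indices.tolist())  # [10, 11, 12, 0, 1]

The resulting kv_indptr / kv_indices pair, together with kv_last_page_len, is what the wrappers' begin_forward calls in the diff above consume.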
sglang/srt/layers/attention/triton_backend.py

@@ -40,6 +40,8 @@ class TritonAttnBackend(AttentionBackend):
 
         self.cuda_graph_max_seq_len = model_runner.model_config.context_len
 
+        self.device = model_runner.device
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
 
@@ -51,7 +53,7 @@ class TritonAttnBackend(AttentionBackend):
             attn_logits = torch.empty(
                 (self.num_head, total_num_tokens),
                 dtype=self.reduce_dtype,
-                device="cuda",
+                device=self.device,
             )
 
             max_seq_len = torch.max(forward_batch.seq_lens).item()
@@ -67,7 +69,7 @@ class TritonAttnBackend(AttentionBackend):
         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
 
         self.cuda_graph_start_loc = torch.zeros(
-            (max_bs,), dtype=torch.int32, device="cuda"
+            (max_bs,), dtype=torch.int32, device=self.device
         )
         self.cuda_graph_attn_logits = torch.empty(
             (
@@ -79,7 +81,7 @@ class TritonAttnBackend(AttentionBackend):
         )
 
     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         self.forward_metadata = (
             self.cuda_graph_start_loc,
@@ -89,7 +91,7 @@ class TritonAttnBackend(AttentionBackend):
         )
 
     def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
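
Stepping back to the flashinfer_backend.py changes above, the two-wrapper setup is worth spelling out: when a model mixes sliding-window and full-attention layers, _get_wrapper_idx routes sliding-window layers to wrapper 0 and full-attention layers (sliding_window_size == -1) to wrapper 1, and the decode indices updater clamps wrapper 0's per-request KV length to sliding_window_size + 1 while wrapper 1 sees the full sequence. The snippet below is a stand-alone illustration of that routing under those assumptions; FakeLayer, wrapper_idx_for, and decode_kv_lens are invented names that only mirror _get_wrapper_idx and update_sliding_window and are not sglang APIs.

from dataclasses import dataclass

import torch


@dataclass
class FakeLayer:
    # Illustrative stand-in for an attention layer; -1 means full attention.
    sliding_window_size: int


def wrapper_idx_for(layer: FakeLayer) -> int:
    # Mirrors _get_wrapper_idx for WrapperDispatch.SLIDING_WINDOW:
    # full-attention layers map to wrapper 1, windowed layers to wrapper 0.
    return int(layer.sliding_window_size == -1)


def decode_kv_lens(seq_lens: torch.Tensor, sliding_window_size: int, wrapper_id: int):
    # Mirrors update_sliding_window in FlashInferIndicesUpdaterDecode:
    # wrapper 0 keeps at most the last sliding_window_size + 1 KV entries,
    # wrapper 1 keeps everything.
    if wrapper_id == 0:
        return torch.minimum(seq_lens, torch.tensor(sliding_window_size + 1))
    return seq_lens


seq_lens = torch.tensor([5, 300, 4096])
print(wrapper_idx_for(FakeLayer(sliding_window_size=256)))   # 0 -> windowed wrapper
print(wrapper_idx_for(FakeLayer(sliding_window_size=-1)))    # 1 -> full-attention wrapper
print(decode_kv_lens(seq_lens, 256, wrapper_id=0).tolist())  # [5, 257, 257]
print(decode_kv_lens(seq_lens, 256, wrapper_id=1).tolist())  # [5, 300, 4096]

The kv_start_idx = seq_lens - paged_kernel_lens line in the diff then makes wrapper 0 start reading each req_to_token row at the beginning of that window rather than at position 0.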