sglang 0.3.4__py3-none-any.whl → 0.3.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. sglang/bench_latency.py +2 -1
  2. sglang/lang/chat_template.py +17 -0
  3. sglang/launch_server_llavavid.py +1 -1
  4. sglang/srt/configs/__init__.py +3 -0
  5. sglang/srt/configs/model_config.py +27 -2
  6. sglang/srt/configs/qwen2vl.py +133 -0
  7. sglang/srt/constrained/fsm_cache.py +10 -3
  8. sglang/srt/conversation.py +27 -0
  9. sglang/srt/hf_transformers_utils.py +16 -1
  10. sglang/srt/layers/attention/__init__.py +16 -5
  11. sglang/srt/layers/attention/double_sparsity_backend.py +22 -6
  12. sglang/srt/layers/attention/flashinfer_backend.py +174 -54
  13. sglang/srt/layers/attention/triton_backend.py +22 -6
  14. sglang/srt/layers/attention/triton_ops/prefill_attention.py +26 -4
  15. sglang/srt/layers/linear.py +89 -63
  16. sglang/srt/layers/logits_processor.py +5 -5
  17. sglang/srt/layers/rotary_embedding.py +112 -0
  18. sglang/srt/layers/sampler.py +51 -39
  19. sglang/srt/lora/lora.py +3 -1
  20. sglang/srt/managers/data_parallel_controller.py +1 -1
  21. sglang/srt/managers/detokenizer_manager.py +4 -0
  22. sglang/srt/managers/image_processor.py +186 -13
  23. sglang/srt/managers/io_struct.py +10 -0
  24. sglang/srt/managers/schedule_batch.py +238 -68
  25. sglang/srt/managers/scheduler.py +69 -50
  26. sglang/srt/managers/tokenizer_manager.py +24 -4
  27. sglang/srt/managers/tp_worker.py +26 -111
  28. sglang/srt/managers/tp_worker_overlap_thread.py +209 -0
  29. sglang/srt/mem_cache/memory_pool.py +56 -10
  30. sglang/srt/mem_cache/radix_cache.py +4 -3
  31. sglang/srt/model_executor/cuda_graph_runner.py +87 -28
  32. sglang/srt/model_executor/forward_batch_info.py +83 -3
  33. sglang/srt/model_executor/model_runner.py +32 -11
  34. sglang/srt/models/chatglm.py +3 -3
  35. sglang/srt/models/deepseek_v2.py +2 -2
  36. sglang/srt/models/mllama.py +1004 -0
  37. sglang/srt/models/qwen2_vl.py +724 -0
  38. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
  39. sglang/srt/sampling/sampling_batch_info.py +13 -3
  40. sglang/srt/sampling/sampling_params.py +5 -7
  41. sglang/srt/server.py +12 -0
  42. sglang/srt/server_args.py +10 -0
  43. sglang/srt/utils.py +22 -0
  44. sglang/test/run_eval.py +2 -0
  45. sglang/test/runners.py +20 -1
  46. sglang/test/srt/sampling/penaltylib/utils.py +1 -0
  47. sglang/test/test_utils.py +100 -3
  48. sglang/version.py +1 -1
  49. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/METADATA +17 -18
  50. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/RECORD +53 -48
  51. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/LICENSE +0 -0
  52. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/WHEEL +0 -0
  53. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py
@@ -11,7 +11,6 @@ from enum import Enum, auto
  from typing import TYPE_CHECKING

  import torch
- import torch.nn as nn
  import triton
  import triton.language as tl

@@ -21,6 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.utils import is_flashinfer_available

  if TYPE_CHECKING:
+ from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.model_runner import ModelRunner

  if is_flashinfer_available():
@@ -56,13 +56,13 @@ class FlashInferAttnBackend(AttentionBackend):

  assert not (
  model_runner.sliding_window_size is not None
- and model_runner.has_cross_attention
+ and model_runner.model_config.is_encoder_decoder
  ), "Sliding window and cross attention are not supported together"

  if model_runner.sliding_window_size is not None:
  self.num_wrappers = 2
  self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
- elif model_runner.has_cross_attention:
+ elif model_runner.model_config.is_encoder_decoder:
  self.num_wrappers = 2
  self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
  else:
@@ -127,6 +127,9 @@ class FlashInferAttnBackend(AttentionBackend):
  self.indices_updater_decode.update(
  forward_batch.req_pool_indices,
  forward_batch.seq_lens,
+ forward_batch.seq_lens_sum,
+ decode_wrappers=None,
+ encoder_lens=forward_batch.encoder_lens,
  )
  self.forward_metadata = (self.decode_wrappers,)
  else:
@@ -134,10 +137,7 @@ class FlashInferAttnBackend(AttentionBackend):

  # Some heuristics to check whether to use ragged forward
  use_ragged = False
- if (
- torch.sum(forward_batch.seq_lens).item() >= 4096
- and self.num_wrappers == 1
- ):
+ if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
  use_ragged = True

  extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()
@@ -146,13 +146,11 @@ class FlashInferAttnBackend(AttentionBackend):
  forward_batch.req_pool_indices,
  forward_batch.seq_lens,
  prefix_lens,
- use_ragged,
+ use_ragged=use_ragged,
+ encoder_lens=forward_batch.encoder_lens,
  )

- self.forward_metadata = (
- use_ragged,
- extend_no_prefix,
- )
+ self.forward_metadata = (use_ragged, extend_no_prefix)

  def init_cuda_graph_state(self, max_bs: int):
  cuda_graph_kv_indices = torch.zeros(
@@ -165,7 +163,11 @@ class FlashInferAttnBackend(AttentionBackend):
  ]

  def init_forward_metadata_capture_cuda_graph(
- self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
+ self,
+ bs: int,
+ req_pool_indices: torch.Tensor,
+ seq_lens: torch.Tensor,
+ encoder_lens: torch.Tensor = None,
  ):
  decode_wrappers = []
  for i in range(self.num_wrappers):
@@ -181,37 +183,59 @@ class FlashInferAttnBackend(AttentionBackend):
  )
  )

- self.indices_updater_decode.update(req_pool_indices, seq_lens, decode_wrappers)
+ seq_lens_sum = seq_lens.sum().item()
+ self.indices_updater_decode.update(
+ req_pool_indices,
+ seq_lens,
+ seq_lens_sum,
+ decode_wrappers=decode_wrappers,
+ encoder_lens=encoder_lens,
+ )
  self.cuda_graph_metadata[bs] = decode_wrappers
  self.forward_metadata = (decode_wrappers,)

  def init_forward_metadata_replay_cuda_graph(
- self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
+ self,
+ bs: int,
+ req_pool_indices: torch.Tensor,
+ seq_lens: torch.Tensor,
+ seq_lens_sum: int,
+ encoder_lens: torch.Tensor = None,
  ):
  self.indices_updater_decode.update(
- req_pool_indices[:bs], seq_lens[:bs], self.cuda_graph_metadata[bs]
+ req_pool_indices[:bs],
+ seq_lens[:bs],
+ seq_lens_sum,
+ decode_wrappers=self.cuda_graph_metadata[bs],
+ encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
  )

  def get_cuda_graph_seq_len_fill_value(self):
  return 0

- def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+ def forward_extend(
+ self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+ ):
  prefill_wrapper_paged = self.prefill_wrappers_paged[
  self._get_wrapper_idx(layer)
  ]

  use_ragged, extend_no_prefix = self.forward_metadata
+ cache_loc = (
+ forward_batch.out_cache_loc
+ if not layer.is_cross_attention
+ else forward_batch.encoder_out_cache_loc
+ )

  if not use_ragged:
  if k is not None:
  assert v is not None
- forward_batch.token_to_kv_pool.set_kv_buffer(
- layer.layer_id, forward_batch.out_cache_loc, k, v
- )
+ forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+
  o = prefill_wrapper_paged.forward(
  q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
  forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
- causal=True,
+ causal=not layer.is_cross_attention,
  sm_scale=layer.scaling,
  window_left=layer.sliding_window_size,
  logits_soft_cap=layer.logit_cap,
@@ -239,20 +263,23 @@ class FlashInferAttnBackend(AttentionBackend):

  o, _ = merge_state(o1, s1, o2, s2)

- forward_batch.token_to_kv_pool.set_kv_buffer(
- layer.layer_id, forward_batch.out_cache_loc, k, v
- )
+ forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

  return o.view(-1, layer.tp_q_head_num * layer.head_dim)

- def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+ def forward_decode(
+ self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+ ):
  decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
+ cache_loc = (
+ forward_batch.out_cache_loc
+ if not layer.is_cross_attention
+ else forward_batch.encoder_out_cache_loc
+ )

  if k is not None:
  assert v is not None
- forward_batch.token_to_kv_pool.set_kv_buffer(
- layer.layer_id, forward_batch.out_cache_loc, k, v
- )
+ forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

  o = decode_wrapper.forward(
  q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -263,7 +290,7 @@ class FlashInferAttnBackend(AttentionBackend):
  )
  return o.view(-1, layer.tp_q_head_num * layer.head_dim)

- def _get_wrapper_idx(self, layer: nn.Module):
+ def _get_wrapper_idx(self, layer: RadixAttention):
  if self.num_wrappers == 1:
  return 0

@@ -290,6 +317,8 @@ class FlashInferIndicesUpdaterDecode:
  self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
  self.sliding_window_size = model_runner.sliding_window_size

+ self.attn_backend = attn_backend
+
  # Buffers and wrappers
  self.kv_indptr = attn_backend.kv_indptr
  self.kv_last_page_len = attn_backend.kv_last_page_len
@@ -297,55 +326,117 @@ class FlashInferIndicesUpdaterDecode:
  self.decode_wrappers = attn_backend.decode_wrappers

  # Dispatch
- if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+ if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
  self.update = self.update_sliding_window
- elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+ elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
  self.update = self.update_cross_attention
  else:
- assert attn_backend.num_wrappers == 1
+ assert self.attn_backend.num_wrappers == 1
  self.update = self.update_single_wrapper

- def update_single_wrapper(self, req_pool_indices, seq_lens, decode_wrappers=None):
+ def update(
+ self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens
+ ):
+ # Keep the signature for type checking. It will be assigned during runtime.
+ raise NotImplementedError()
+
+ def update_single_wrapper(
+ self,
+ req_pool_indices: torch.Tensor,
+ seq_lens: torch.Tensor,
+ seq_lens_sum: int,
+ decode_wrappers=None,
+ encoder_lens=None,
+ ):
  decode_wrappers = decode_wrappers or self.decode_wrappers
  self.call_begin_forward(
- decode_wrappers[0], req_pool_indices, seq_lens, self.kv_indptr[0], None
+ decode_wrappers[0],
+ req_pool_indices,
+ seq_lens,
+ seq_lens_sum,
+ self.kv_indptr[0],
+ None,
  )

- def update_sliding_window(self, req_pool_indices, seq_lens, decode_wrappers=None):
+ def update_sliding_window(
+ self,
+ req_pool_indices: torch.Tensor,
+ seq_lens: torch.Tensor,
+ seq_lens_sum: int,
+ decode_wrappers=None,
+ encoder_lens=None,
+ ):
  decode_wrappers = decode_wrappers or self.decode_wrappers

  for wrapper_id in range(2):
  if wrapper_id == 0:
  # Sliding window attention
- paged_kernel_lens = torch.minimum( # TODO: replace this with clamp
+ paged_kernel_lens_tmp = torch.minimum( # TODO: replace this with clamp
  seq_lens,
  torch.tensor(self.sliding_window_size + 1),
  )
+ paged_kernel_lens_sum_tmp = paged_kernel_lens_tmp.sum().item()
+ kv_start_idx_tmp = seq_lens - paged_kernel_lens_tmp
  else:
  # Full attention
- paged_kernel_lens = seq_lens
+ paged_kernel_lens_tmp = seq_lens
+ paged_kernel_lens_sum_tmp = seq_lens_sum
+ kv_start_idx_tmp = None

- kv_start_idx = seq_lens - paged_kernel_lens
+ self.call_begin_forward(
+ decode_wrappers[wrapper_id],
+ req_pool_indices,
+ paged_kernel_lens_tmp,
+ paged_kernel_lens_sum_tmp,
+ self.kv_indptr[wrapper_id],
+ kv_start_idx_tmp,
+ )
+
+ def update_cross_attention(
+ self,
+ req_pool_indices,
+ seq_lens,
+ seq_lens_sum,
+ decode_wrappers=None,
+ encoder_lens=None,
+ ):
+ decode_wrappers = decode_wrappers or self.decode_wrappers
+
+ for wrapper_id in range(2):
+ if wrapper_id == 0:
+ # Normal attention
+ paged_kernel_lens = seq_lens
+ kv_start_idx = encoder_lens
+ else:
+ # Cross attention
+ paged_kernel_lens = encoder_lens
+ kv_start_idx = torch.zeros_like(encoder_lens)
+ seq_lens_sum = encoder_lens.sum().item()

  self.call_begin_forward(
  decode_wrappers[wrapper_id],
  req_pool_indices,
  paged_kernel_lens,
+ seq_lens_sum,
  self.kv_indptr[wrapper_id],
  kv_start_idx,
  )

- def update_cross_attention(self):
- raise NotImplementedError()
-
  def call_begin_forward(
- self, wrapper, req_pool_indices, paged_kernel_lens, kv_indptr, kv_start_idx
+ self,
+ wrapper,
+ req_pool_indices,
+ paged_kernel_lens,
+ paged_kernel_lens_sum,
+ kv_indptr,
+ kv_start_idx,
  ):
  bs = len(req_pool_indices)
+ kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
  kv_indptr = kv_indptr[: bs + 1]
- # TODO: optimize the blocking call on kv_indptr[-1]
- kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
- kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+ kv_indices = torch.empty(
+ paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+ )

  create_flashinfer_kv_indices_triton[(bs,)](
  self.req_to_token,
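
Note: the hunk above turns FlashInferIndicesUpdaterDecode's update into a dispatch slot. A stub update keeps the full signature so type checkers see it and raises NotImplementedError, while __init__ rebinds self.update to update_single_wrapper, update_sliding_window, or update_cross_attention based on dispatch_reason. A minimal standalone sketch of that pattern, with illustrative names rather than the real sglang classes:

    from enum import Enum, auto

    class WrapperDispatch(Enum):
        SLIDING_WINDOW = auto()
        CROSS_ATTENTION = auto()

    class UpdaterSketch:
        def __init__(self, dispatch_reason, num_wrappers):
            # Rebind `update` once so the per-step hot path has no branching.
            if dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
                self.update = self.update_sliding_window
            elif dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
                self.update = self.update_cross_attention
            else:
                assert num_wrappers == 1
                self.update = self.update_single_wrapper

        def update(self, seq_lens, seq_lens_sum, encoder_lens=None):
            # Keep the signature for type checking; replaced in __init__.
            raise NotImplementedError()

        def update_single_wrapper(self, seq_lens, seq_lens_sum, encoder_lens=None):
            return ("single", seq_lens_sum)

        def update_sliding_window(self, seq_lens, seq_lens_sum, encoder_lens=None):
            return ("sliding_window", seq_lens_sum)

        def update_cross_attention(self, seq_lens, seq_lens_sum, encoder_lens=None):
            return ("cross", sum(encoder_lens))

    # Example: the cross-attention path keys off the encoder lengths.
    updater = UpdaterSketch(WrapperDispatch.CROSS_ATTENTION, num_wrappers=2)
    assert updater.update([3, 5], 8, encoder_lens=[4, 4]) == ("cross", 8)
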
@@ -386,6 +477,8 @@ class FlashInferIndicesUpdaterPrefill:
  self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
  self.sliding_window_size = model_runner.sliding_window_size

+ self.attn_backend = attn_backend
+
  # Buffers and wrappers
  self.kv_indptr = attn_backend.kv_indptr
  self.kv_last_page_len = attn_backend.kv_last_page_len
@@ -395,16 +488,20 @@ class FlashInferIndicesUpdaterPrefill:
  self.wrappers_paged = attn_backend.prefill_wrappers_paged

  # Dispatch
- if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+ if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
  self.update = self.update_sliding_window
- elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+ elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
  self.update = self.update_cross_attention
  else:
- assert attn_backend.num_wrappers == 1
+ assert self.attn_backend.num_wrappers == 1
  self.update = self.update_single_wrapper

+ def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens):
+ # Keep the signature for type checking. It will be assigned during runtime.
+ raise NotImplementedError()
+
  def update_single_wrapper(
- self, req_pool_indices, seq_lens, prefix_lens, use_ragged
+ self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
  ):
  if use_ragged:
  paged_kernel_lens = prefix_lens
@@ -425,7 +522,7 @@ class FlashInferIndicesUpdaterPrefill:
  )

  def update_sliding_window(
- self, req_pool_indices, seq_lens, prefix_lens, use_ragged
+ self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
  ):
  for wrapper_id in range(2):
  if wrapper_id == 0:
@@ -452,8 +549,31 @@ class FlashInferIndicesUpdaterPrefill:
  use_ragged,
  )

- def update_cross_attention(self):
- raise NotImplementedError()
+ def update_cross_attention(
+ self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
+ ):
+ for wrapper_id in range(2):
+ if wrapper_id == 0:
+ # normal attention
+ paged_kernel_lens = seq_lens
+ kv_start_idx = encoder_lens
+ else:
+ # cross attention
+ paged_kernel_lens = encoder_lens
+ kv_start_idx = torch.zeros_like(encoder_lens)
+
+ self.call_begin_forward(
+ self.wrapper_ragged,
+ self.wrappers_paged[wrapper_id],
+ req_pool_indices,
+ paged_kernel_lens,
+ seq_lens,
+ prefix_lens,
+ kv_start_idx,
+ self.kv_indptr[wrapper_id],
+ self.qo_indptr[wrapper_id],
+ use_ragged,
+ )

  def call_begin_forward(
  self,
@@ -469,8 +589,8 @@ class FlashInferIndicesUpdaterPrefill:
  use_ragged,
  ):
  bs = len(req_pool_indices)
+ kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
  kv_indptr = kv_indptr[: bs + 1]
- kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
  kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
  create_flashinfer_kv_indices_triton[(bs,)](
  self.req_to_token,
@@ -482,8 +602,8 @@ class FlashInferIndicesUpdaterPrefill:
  self.max_context_len,
  )

+ qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
  qo_indptr = qo_indptr[: bs + 1]
- qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

  # extend part
  if use_ragged:
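
Note: in both updaters, call_begin_forward now writes the cumulative sum into kv_indptr[1 : bs + 1] before slicing, and the decode path sizes kv_indices from a host-side paged_kernel_lens_sum instead of reading kv_indptr[-1] back from the device (the removed "# TODO: optimize the blocking call on kv_indptr[-1]"). A rough CPU-only sketch of the before/after; the helper names below are ours, and the real code then fills kv_indices with create_flashinfer_kv_indices_triton:

    import torch

    def build_kv_indices_old(kv_indptr, paged_kernel_lens):
        # Old pattern: size kv_indices from kv_indptr[-1], a device scalar,
        # which forces a GPU->CPU sync when these tensors live on the GPU.
        bs = paged_kernel_lens.numel()
        kv_indptr = kv_indptr[: bs + 1]
        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
        return torch.empty(int(kv_indptr[-1]), dtype=torch.int32)

    def build_kv_indices_new(kv_indptr, paged_kernel_lens, paged_kernel_lens_sum):
        # New pattern: the total length arrives as a Python int (seq_lens_sum is
        # threaded down from the batch), so no device read-back is needed.
        bs = paged_kernel_lens.numel()
        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
        kv_indptr = kv_indptr[: bs + 1]
        return torch.empty(paged_kernel_lens_sum, dtype=torch.int32)

    lens = torch.tensor([3, 5, 2])
    indptr = torch.zeros(8, dtype=torch.int64)
    assert build_kv_indices_new(indptr, lens, int(lens.sum())).numel() == 10
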
sglang/srt/layers/attention/triton_backend.py
@@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch

  if TYPE_CHECKING:
+ from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.model_runner import ModelRunner


@@ -81,8 +82,13 @@ class TritonAttnBackend(AttentionBackend):
  )

  def init_forward_metadata_capture_cuda_graph(
- self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
+ self,
+ bs: int,
+ req_pool_indices: torch.Tensor,
+ seq_lens: torch.Tensor,
+ encoder_lens=None,
  ):
+ # NOTE: encoder_lens expected to be zeros or None
  self.forward_metadata = (
  self.cuda_graph_start_loc,
  self.cuda_graph_attn_logits,
@@ -91,15 +97,23 @@ class TritonAttnBackend(AttentionBackend):
  )

  def init_forward_metadata_replay_cuda_graph(
- self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
+ self,
+ bs: int,
+ req_pool_indices: torch.Tensor,
+ seq_lens: torch.Tensor,
+ seq_lens_sum: int,
+ encoder_lens=None,
  ):
+ # NOTE: encoder_lens expected to be zeros or None
  self.cuda_graph_start_loc.zero_()
  self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)

  def get_cuda_graph_seq_len_fill_value(self):
  return 1

- def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+ def forward_extend(
+ self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+ ):
  # TODO: reuse the buffer across layers
  if layer.qk_head_dim != layer.v_head_dim:
  o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
@@ -107,7 +121,7 @@ class TritonAttnBackend(AttentionBackend):
  o = torch.empty_like(q)

  forward_batch.token_to_kv_pool.set_kv_buffer(
- layer.layer_id, forward_batch.out_cache_loc, k, v
+ layer, forward_batch.out_cache_loc, k, v
  )

  start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
@@ -129,7 +143,9 @@ class TritonAttnBackend(AttentionBackend):
  )
  return o

- def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+ def forward_decode(
+ self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+ ):
  # During torch.compile, there is a bug in rotary_emb that causes the
  # output value to have a 3D tensor shape. This reshapes the output correctly.
  q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
@@ -143,7 +159,7 @@ class TritonAttnBackend(AttentionBackend):
  start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata

  forward_batch.token_to_kv_pool.set_kv_buffer(
- layer.layer_id, forward_batch.out_cache_loc, k, v
+ layer, forward_batch.out_cache_loc, k, v
  )

  self.decode_attention_fwd(
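
Note: with these changes the FlashInfer and Triton backends expose matching hooks: the CUDA-graph capture path gains encoder_lens, the replay path gains seq_lens_sum and encoder_lens, and forward_extend / forward_decode now type the layer argument as RadixAttention instead of nn.Module. A minimal sketch of the interface as it reads from these call sites; the real base class lives in sglang/srt/layers/attention/__init__.py (also changed in this release) and may differ in detail:

    from abc import ABC, abstractmethod
    from typing import Optional

    import torch

    class AttentionBackendSketch(ABC):
        # Hooks used when capturing a CUDA graph for decode.
        @abstractmethod
        def init_forward_metadata_capture_cuda_graph(
            self,
            bs: int,
            req_pool_indices: torch.Tensor,
            seq_lens: torch.Tensor,
            encoder_lens: Optional[torch.Tensor] = None,
        ): ...

        # Hooks used when replaying the captured graph with fresh metadata.
        @abstractmethod
        def init_forward_metadata_replay_cuda_graph(
            self,
            bs: int,
            req_pool_indices: torch.Tensor,
            seq_lens: torch.Tensor,
            seq_lens_sum: int,
            encoder_lens: Optional[torch.Tensor] = None,
        ): ...

        @abstractmethod
        def get_cuda_graph_seq_len_fill_value(self) -> int:
            # 0 for the FlashInfer backend, 1 for the Triton backend above.
            ...

        @abstractmethod
        def forward_extend(self, q, k, v, layer, forward_batch): ...

        @abstractmethod
        def forward_decode(self, q, k, v, layer, forward_batch): ...
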
sglang/srt/layers/attention/triton_ops/prefill_attention.py
@@ -50,6 +50,7 @@ def _fwd_kernel(
  BLOCK_M: tl.constexpr,
  BLOCK_DMODEL: tl.constexpr,
  BLOCK_N: tl.constexpr,
+ IS_CAUSAL: tl.constexpr,
  Lk: tl.constexpr,
  ):
  cur_batch = tl.program_id(0)
@@ -78,7 +79,9 @@ def _fwd_kernel(
  mask_d = offs_d < Lk

  q = tl.load(
- Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d), other=0.0
+ Q + off_q,
+ mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),
+ other=0.0,
  )

  k_ptrs = K + off_k
@@ -91,7 +94,12 @@ def _fwd_kernel(

  block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)

- for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
+ end_n = (
+ cur_batch_seq_len
+ if not IS_CAUSAL
+ else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len)
+ )
+ for start_n in range(0, block_mask * end_n, BLOCK_N):
  start_n = tl.multiple_of(start_n, BLOCK_N)
  # -- compute qk ----
  k = tl.load(
@@ -104,7 +112,18 @@ def _fwd_kernel(
  qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
  qk += tl.dot(q, k)
  qk *= sm_scale
- qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
+
+ if IS_CAUSAL:
+ qk += tl.where(
+ (start_n + offs_n[None, :] < cur_batch_seq_len)
+ & (offs_m[:, None] >= (start_n + offs_n[None, :])),
+ 0,
+ float("-inf"),
+ )
+ else:
+ qk += tl.where(
+ (start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float("-inf")
+ )

  # -- compute m_ij, p, l_ij
  m_ij = tl.max(qk, 1)
@@ -146,7 +165,9 @@ def _fwd_kernel(
  )


- def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
+ def context_attention_fwd(
+ q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
+ ):
  if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
  BLOCK = 128
  else:
@@ -181,6 +202,7 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
  BLOCK_M=BLOCK,
  BLOCK_DMODEL=triton.next_power_of_2(Lk),
  BLOCK_N=BLOCK,
+ IS_CAUSAL=is_causal,
  num_warps=num_warps,
  num_stages=1,
  Lk=Lk,
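
Note: the new IS_CAUSAL switch in _fwd_kernel controls two things: the key-scan bound end_n (the full sequence vs. only up to the current query block) and the qk mask (bounds-only vs. bounds plus key index <= query index). A small CPU reference in PyTorch for what context_attention_fwd computes per packed sequence under each setting; the function name here is illustrative, not part of sglang:

    import math

    import torch

    def context_attention_ref(q, k, v, b_start_loc, b_seq_len, is_causal=True):
        # q, k, v: [total_tokens, num_heads, head_dim], sequences packed along dim 0.
        # b_start_loc[i], b_seq_len[i]: where sequence i starts and how long it is.
        o = torch.empty_like(q)
        sm_scale = 1.0 / math.sqrt(q.shape[-1])
        for start, length in zip(b_start_loc.tolist(), b_seq_len.tolist()):
            qi = q[start : start + length]  # [L, H, D]
            ki = k[start : start + length]
            vi = v[start : start + length]
            scores = torch.einsum("qhd,khd->hqk", qi, ki) * sm_scale
            if is_causal:
                # Mask keys that come after the query position.
                causal = torch.tril(torch.ones(length, length, dtype=torch.bool))
                scores = scores.masked_fill(~causal, float("-inf"))
            probs = torch.softmax(scores, dim=-1)
            o[start : start + length] = torch.einsum("hqk,khd->qhd", probs, vi)
        return o

With is_causal=False, every query attends to all keys of its own sequence, which is what the cross-attention paths added elsewhere in this release rely on.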