tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202512030818__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.
Files changed (54)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/lora/test_layers.py +0 -6
  3. tests/lora/utils.py +0 -8
  4. tests/test_envs.py +32 -11
  5. tests/test_utils.py +1 -2
  6. tpu_inference/__init__.py +22 -3
  7. tpu_inference/core/disagg_utils.py +6 -8
  8. tpu_inference/distributed/tpu_connector.py +3 -4
  9. tpu_inference/distributed/utils.py +3 -2
  10. tpu_inference/envs.py +61 -8
  11. tpu_inference/executors/ray_distributed_executor.py +31 -11
  12. tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
  13. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +213 -126
  15. tpu_inference/layers/common/attention_interface.py +7 -1
  16. tpu_inference/layers/common/sharding.py +5 -5
  17. tpu_inference/layers/vllm/fused_moe.py +74 -25
  18. tpu_inference/layers/vllm/quantization/common.py +6 -1
  19. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -62
  20. tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
  21. tpu_inference/layers/vllm/sharding.py +2 -2
  22. tpu_inference/lora/torch_punica_tpu.py +1 -2
  23. tpu_inference/models/common/model_loader.py +45 -11
  24. tpu_inference/models/jax/llama3.py +2 -1
  25. tpu_inference/models/jax/llama_eagle3.py +8 -5
  26. tpu_inference/models/jax/llama_guard_4.py +361 -0
  27. tpu_inference/models/jax/qwen2.py +2 -1
  28. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  29. tpu_inference/models/jax/qwen3.py +2 -1
  30. tpu_inference/models/jax/utils/quantization/quantization_utils.py +3 -6
  31. tpu_inference/models/jax/utils/weight_utils.py +198 -143
  32. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -7
  33. tpu_inference/platforms/tpu_platform.py +28 -22
  34. tpu_inference/runner/compilation_manager.py +144 -59
  35. tpu_inference/runner/kv_cache_manager.py +17 -18
  36. tpu_inference/runner/persistent_batch_manager.py +40 -2
  37. tpu_inference/runner/structured_decoding_manager.py +2 -3
  38. tpu_inference/runner/tpu_runner.py +271 -147
  39. tpu_inference/runner/utils.py +2 -2
  40. tpu_inference/spec_decode/jax/eagle3.py +71 -21
  41. tpu_inference/tpu_info.py +4 -3
  42. tpu_inference/utils.py +36 -13
  43. tpu_inference/worker/tpu_worker.py +162 -25
  44. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/METADATA +3 -2
  45. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/RECORD +48 -53
  46. tpu_inference/mock/__init__.py +0 -0
  47. tpu_inference/mock/vllm_config_utils.py +0 -28
  48. tpu_inference/mock/vllm_envs.py +0 -1219
  49. tpu_inference/mock/vllm_logger.py +0 -212
  50. tpu_inference/mock/vllm_logging_utils.py +0 -15
  51. tpu_inference/models/jax/phi3.py +0 -376
  52. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/WHEEL +0 -0
  53. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/licenses/LICENSE +0 -0
  54. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/top_level.txt +0 -0
@@ -440,42 +440,54 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
     debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
 
-    # Fetch effective kv from kv cache.
-    def loop_body(i, offset):
-      sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
-      _async_copy(
-          cache_hbm_ref.at[pl.ds(
-              page_indices_ref[page_indices_offset + i] * page_size,
-              sz)],
-          vmem_ref.at[pl.ds(i * page_size, sz)],
-          sem,
-          wait,
+    if not wait:
+      # Fetch effective kv from kv cache.
+      def loop_body(i, offset):
+        sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
+        _async_copy(
+            cache_hbm_ref.at[pl.ds(
+                page_indices_ref[page_indices_offset + i] * page_size,
+                sz)],
+            vmem_ref.at[pl.ds(i * page_size, sz)],
+            sem,
+            wait=False,
+        )
+        debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+        return offset + sz
+
+      offset = lax.fori_loop(
+          0,
+          bkv_p_frm_cache,
+          loop_body,
+          0,  # offset
+          unroll=False,
       )
-      debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-      return offset + sz
-
-    offset = lax.fori_loop(
-        0,
-        bkv_p_frm_cache,
-        loop_body,
-        0,  # offset
-        unroll=False,
-    )
 
-    # Fetch kv directly from new kv.
-    @pl.when(bkv_sz_frm_new > 0)
-    def _fetch_bkv_from_new_kv():
-      new_kv_len_start = q_end - kv_left_frm_new
-      debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
-      debug_print("[RPA debug] offset_in_bkv={}", offset)
+      # Fetch kv directly from new kv.
+      @pl.when(bkv_sz_frm_new > 0)
+      def _fetch_bkv_from_new_kv():
+        new_kv_len_start = q_end - kv_left_frm_new
+        debug_print("[RPA debug] new_kv_len_start={}",
+                    new_kv_len_start)
+        debug_print("[RPA debug] offset_in_bkv={}", offset)
+        _async_copy(
+            kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
+            vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
+            sem,
+            wait,
+        )
+
+      return kv_len_start + offset, bkv_sz_frm_new
+    else:
+      offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
+      dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
       _async_copy(
-          kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
-          vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
-          sem,
-          wait,
+          src=dst,
+          dst=dst,
+          sem=sem,
+          wait=True,
       )
-
-      return kv_len_start + offset, bkv_sz_frm_new
+      return kv_len_start + offset, bkv_sz_frm_new
 
   def _update_kv_cache(seq_idx,
                        bkv_sem_idx,
@@ -511,30 +523,41 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] p_ignore={}", p_ignore)
     debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
 
-    def loop_body(i, states):
-      update_sz, ignore = states
-      sz = jnp.minimum(page_size - ignore, update_sz)
-
+    if not wait:
+
+      def loop_body(i, states):
+        update_sz, ignore = states
+        sz = jnp.minimum(page_size - ignore, update_sz)
+
+        _async_copy(
+            vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
+                              sz)],
+            cache_hbm_ref.at[pl.ds(
+                page_indices_ref[page_indices_offset + i] * page_size +
+                ignore,
+                sz,
+            )],
+            sem,
+            wait=False,
+        )
+        debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+        return update_sz - sz, 0
+
+      lax.fori_loop(
+          0,
+          kv_p_end - kv_p_start,
+          loop_body,
+          (update_sz, ignore),  # total transfer size
+          unroll=False,
+      )
+    else:
+      dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
       _async_copy(
-          vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore, sz)],
-          cache_hbm_ref.at[pl.ds(
-              page_indices_ref[page_indices_offset + i] * page_size +
-              ignore,
-              sz,
-          )],
-          sem,
-          wait,
+          src=dst,
+          dst=dst,
+          sem=sem,
+          wait=True,
       )
-      debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-      return update_sz - sz, 0
-
-    lax.fori_loop(
-        0,
-        kv_p_end - kv_p_start,
-        loop_body,
-        (update_sz, ignore),  # total transfer size
-        unroll=False,
-    )
 
   def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
     sem = sems.at[1, bq_sem_idx]
@@ -267,6 +267,7 @@ def _ragged_paged_attention_kernel(
     *,
     sm_scale: float,
     sliding_window: int | None = None,
+    strict_sliding_window: bool = True,
     soft_cap: float | None = None,
     mask_value: float = DEFAULT_MASK_VALUE,
     q_scale: float | None = None,
@@ -317,19 +318,20 @@ def _ragged_paged_attention_kernel(
     q_len = q_end - q_start
     kv_len = kv_lens_ref[seq_idx]
 
-    bkv_idx_start = 0 if sliding_window is None else jnp.maximum(
-        kv_len - sliding_window, 0) // bkv_sz
-
     if sliding_window is None:
-      next_bkv_idx_start = 0
+      bkv_idx_start = next_seq_bkv_idx_start = 0
     else:
+      bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window,
+                                  0) // bkv_sz
 
       def get_next_bkv_idx_start():
        next_kv_len = kv_lens_ref[seq_idx + 1]
-        return jnp.maximum(next_kv_len - sliding_window, 0) // bkv_sz
+        next_q_len = cu_q_lens_ref[seq_idx + 2] - q_end
+        return jnp.maximum(next_kv_len - next_q_len - sliding_window,
+                           0) // bkv_sz
 
-      next_bkv_idx_start = lax.cond(seq_idx + 1 < num_seqs,
-                                    get_next_bkv_idx_start, lambda: 0)
+      next_seq_bkv_idx_start = lax.cond(seq_idx + 1 < num_seqs,
+                                        get_next_bkv_idx_start, lambda: 0)
 
     def debug_print(msg, *args):
       if debug_mode:
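
A note on the start-block arithmetic in the hunk above: with a sliding window, the first KV block to fetch is now derived from the earliest query token of the sequence (absolute position kv_len - q_len) instead of from kv_len alone. A minimal sketch of that index calculation under hypothetical example values; the helper name first_bkv_block is illustrative and not part of the kernel:

    import jax.numpy as jnp

    def first_bkv_block(kv_len, q_len, sliding_window, bkv_sz):
        # The earliest query token sits at absolute position kv_len - q_len, so the
        # oldest key it may still attend to is near kv_len - q_len - sliding_window.
        return jnp.maximum(kv_len - q_len - sliding_window, 0) // bkv_sz

    # Example: kv_len=600, q_len=100, sliding_window=256, bkv_sz=128.
    # Old formula: max(600 - 256, 0) // 128 = 2, which skips block 1 even though
    # the earliest query token still needs keys from it.
    # New formula: max(600 - 100 - 256, 0) // 128 = 1, so block 1 is kept.
    print(first_bkv_block(600, 100, 256, 128))  # -> 1
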
@@ -350,7 +352,7 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] q_len={}", q_len)
     debug_print("[RPA debug] kv_len={}", kv_len)
 
-  def flash_attention(
+  def flash_attention_step1_qk_softmax(
       q,  # [actual_bq_sz * num_q_heads_per_kv_head, actual_head_dim_x2]
       kv,  # [bkv_sz, actual_head_dim_x2]
       *,
@@ -364,7 +366,6 @@ def _ragged_paged_attention_kernel(
     assert kv.shape == (bkv_sz, actual_head_dim_x2)
     head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
     head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
-    head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]
 
     def load_with_init(ref, init_val):
       return jnp.where(bkv_idx == bkv_idx_start,
@@ -386,16 +387,19 @@ def _ragged_paged_attention_kernel(
       s *= k_scale
     if q_scale is not None:
       s *= q_scale
+    if soft_cap is not None:
+      s = soft_cap * jnp.tanh(s / soft_cap)
 
     q_span = (kv_len - q_len + bq_idx * bq_sz +
               lax.broadcasted_iota(jnp.int32, s.shape, 0) //
               num_q_heads_per_kv_head)
     k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
-    mask = q_span < k_span
+    mask = k_span <= q_span
 
-    if soft_cap is not None:
-      s = soft_cap * jnp.tanh(s / soft_cap)
-    s += jnp.where(mask, mask_value, 0.0)
+    if sliding_window is not None and strict_sliding_window:
+      mask = jnp.logical_and(mask, q_span - sliding_window < k_span)
+
+    s = jnp.where(mask, s, mask_value)
     s_rowmax = jnp.max(s, axis=1, keepdims=True)
 
     if attention_sink_ref is not None:
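
In the masking hunk above, the mask now marks positions to keep (causal condition k_span <= q_span) rather than positions to null out, soft-capping runs before masking, and the new strict_sliding_window flag additionally drops keys older than the window on a per-token basis. A standalone sketch of the same mask construction on a toy score matrix; the head interleaving (// num_q_heads_per_kv_head) is omitted and all sizes are illustrative:

    import jax.numpy as jnp
    from jax import lax

    bq_sz, bkv_sz, sliding_window, mask_value = 4, 8, 3, -1e30
    kv_len, q_len, bq_idx, bkv_idx = 8, 4, 0, 0

    s = jnp.zeros((bq_sz, bkv_sz), dtype=jnp.float32)  # stand-in for the q @ k.T scores
    q_span = (kv_len - q_len + bq_idx * bq_sz +
              lax.broadcasted_iota(jnp.int32, s.shape, 0))                # query positions 4..7
    k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)  # key positions 0..7

    mask = k_span <= q_span                                         # causal: keep current and past keys
    mask = jnp.logical_and(mask, q_span - sliding_window < k_span)  # strict window lower bound
    s = jnp.where(mask, s, mask_value)                              # masked-out scores become mask_value
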
@@ -411,15 +415,33 @@ def _ragged_paged_attention_kernel(
     head_m_ref[...] = m_curr
     p = jnp.exp(s - broadcast_minor(m_curr, s.shape))
 
-    pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
-    if v_scale is not None:
-      pv *= v_scale
-
     p_rowsum = jnp.sum(p, axis=1, keepdims=True)
     exp_m_diff = jnp.exp(m_prev - m_curr)
     l_prev = load_with_init(head_l_ref, 1.0)
     l_curr = exp_m_diff * l_prev + p_rowsum
     head_l_ref[...] = l_curr
+
+    return p, exp_m_diff
+
+  def flash_attention_step2_pv(
+      q_shape_0,
+      kv,  # [bkv_sz, actual_head_dim_x2]
+      p,  # from step1
+      exp_m_diff,  # from step1
+      *,
+      bkv_idx,
+      kv_head_idx,
+  ):
+    head_acc_ref = acc_ref.at[kv_head_idx, :q_shape_0]
+
+    def load_with_init(ref, init_val):
+      return jnp.where(bkv_idx == bkv_idx_start,
+                       jnp.full_like(ref, init_val), ref[...])
+
+    pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
+    if v_scale is not None:
+      pv *= v_scale
+
     o_prev = load_with_init(head_acc_ref, 0.0)
     o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
     head_acc_ref[...] = o_curr
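
For orientation on the split above: step1 updates the running row maximum m and denominator l and returns the unnormalized weights p plus the rescale factor exp_m_diff, while step2 rescales the output accumulator by that same factor before adding p @ V. A tiny numeric check of the online-softmax recurrence, illustrative only (one query row, two KV blocks, the p @ V term omitted):

    import jax.numpy as jnp

    s = jnp.array([[1.0, 3.0, 2.0, 0.5]])   # one query row of scores
    blocks = [s[:, :2], s[:, 2:]]            # two kv blocks, processed in order

    m = jnp.full((1, 1), -jnp.inf)           # running row max
    l = jnp.zeros((1, 1))                    # running softmax denominator
    for blk in blocks:
        m_curr = jnp.maximum(m, blk.max(axis=1, keepdims=True))
        p = jnp.exp(blk - m_curr)            # step1: unnormalized weights for this block
        exp_m_diff = jnp.exp(m - m_curr)     # step1: rescale factor for the old state
        l = exp_m_diff * l + p.sum(axis=1, keepdims=True)
        # step2 would rescale the output accumulator by the same exp_m_diff
        # and add p @ v for this block.
        m = m_curr

    # Matches the one-shot softmax denominator over the full row.
    assert jnp.allclose(l, jnp.exp(s - s.max()).sum())
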
@@ -475,42 +497,54 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
     debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
 
-    # Fetch effective kv from kv cache.
-    def loop_body(i, offset):
-      sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
-      _async_copy(
-          cache_hbm_ref.at[pl.ds(
-              page_indices_ref[page_indices_offset + i] * page_size,
-              sz)],
-          vmem_ref.at[pl.ds(i * page_size, sz)],
-          sem,
-          wait,
+    if not wait:
+      # Fetch effective kv from kv cache.
+      def loop_body(i, offset):
+        sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
+        _async_copy(
+            cache_hbm_ref.at[pl.ds(
+                page_indices_ref[page_indices_offset + i] * page_size,
+                sz)],
+            vmem_ref.at[pl.ds(i * page_size, sz)],
+            sem,
+            wait=False,
+        )
+        debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+        return offset + sz
+
+      offset = lax.fori_loop(
+          0,
+          bkv_p_frm_cache,
+          loop_body,
+          0,  # offset
+          unroll=False,
       )
-      debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-      return offset + sz
-
-    offset = lax.fori_loop(
-        0,
-        bkv_p_frm_cache,
-        loop_body,
-        0,  # offset
-        unroll=False,
-    )
 
-    # Fetch kv directly from new kv.
-    @pl.when(bkv_sz_frm_new > 0)
-    def _fetch_bkv_from_new_kv():
-      new_kv_len_start = q_end - kv_left_frm_new
-      debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
-      debug_print("[RPA debug] offset_in_bkv={}", offset)
+      # Fetch kv directly from new kv.
+      @pl.when(bkv_sz_frm_new > 0)
+      def _fetch_bkv_from_new_kv():
+        new_kv_len_start = q_end - kv_left_frm_new
+        debug_print("[RPA debug] new_kv_len_start={}",
+                    new_kv_len_start)
+        debug_print("[RPA debug] offset_in_bkv={}", offset)
+        _async_copy(
+            kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
+            vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
+            sem,
+            wait,
+        )
+
+      return kv_len_start + offset, bkv_sz_frm_new
+    else:
+      offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
+      dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
       _async_copy(
-          kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
-          vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
-          sem,
-          wait,
+          src=dst,
+          dst=dst,
+          sem=sem,
+          wait=True,
      )
-
-      return kv_len_start + offset, bkv_sz_frm_new
+      return kv_len_start + offset, bkv_sz_frm_new
 
   def _update_kv_cache(seq_idx,
                        bkv_sem_idx,
@@ -546,30 +580,41 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] p_ignore={}", p_ignore)
     debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
 
-    def loop_body(i, states):
-      update_sz, ignore = states
-      sz = jnp.minimum(page_size - ignore, update_sz)
-
+    if not wait:
+
+      def loop_body(i, states):
+        update_sz, ignore = states
+        sz = jnp.minimum(page_size - ignore, update_sz)
+
+        _async_copy(
+            vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
+                              sz)],
+            cache_hbm_ref.at[pl.ds(
+                page_indices_ref[page_indices_offset + i] * page_size +
+                ignore,
+                sz,
+            )],
+            sem,
+            wait=False,
+        )
+        debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+        return update_sz - sz, 0
+
+      lax.fori_loop(
+          0,
+          kv_p_end - kv_p_start,
+          loop_body,
+          (update_sz, ignore),  # total transfer size
+          unroll=False,
+      )
+    else:
+      dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
       _async_copy(
-          vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore, sz)],
-          cache_hbm_ref.at[pl.ds(
-              page_indices_ref[page_indices_offset + i] * page_size +
-              ignore,
-              sz,
-          )],
-          sem,
-          wait,
+          src=dst,
+          dst=dst,
+          sem=sem,
+          wait=True,
      )
-      debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-      return update_sz - sz, 0
-
-    lax.fori_loop(
-        0,
-        kv_p_end - kv_p_start,
-        loop_body,
-        (update_sz, ignore),  # total transfer size
-        unroll=False,
-    )
 
   def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
     sem = sems.at[1, bq_sem_idx]
@@ -737,13 +782,17 @@ def _ragged_paged_attention_kernel(
     next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
     next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)
 
-    next_bkv_idx = lax.select(
-        is_last_bkv,
-        lax.select(
+    if sliding_window is None:
+      next_bkv_start_idx = 0
+    else:
+      next_bkv_start_idx = lax.select(
           is_last_bq,
-          next_bkv_idx_start,
+          next_seq_bkv_idx_start,
           bkv_idx_start,
-        ), next_bkv_idx)
+      )
+    next_bkv_idx = lax.select(is_last_bkv, next_bkv_start_idx,
+                              next_bkv_idx)
+
     return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx
 
   def compute_with_bq(bq_idx, _):
@@ -803,6 +852,11 @@ def _ragged_paged_attention_kernel(
         return
 
       # Flash attention with cur bkv and bq
+      prev_bq_shape_0 = None
+      prev_kv_head_bkv = None
+      prev_kv_head_idx = None
+      prev_kv_head_p = None
+      prev_kv_head_exp_m_diff = None
      for kv_head_start in range(0, actual_num_kv_heads, kv_packing):
        bkv_lst = strided_load_bkv(
            bkv_sem_idx,
@@ -812,20 +866,51 @@ def _ragged_paged_attention_kernel(
       )
       assert len(bkv_lst) == kv_packing
       for i in range(kv_packing):
-        kv_head_idx = kv_head_start + i
-        if kv_head_idx >= actual_num_kv_heads:
+        cur_kv_head_idx = kv_head_start + i
+        if cur_kv_head_idx >= actual_num_kv_heads:
           break
-        bq = load_bq(bq_sem_idx,
-                     kv_head_idx,
-                     actual_bq_sz=actual_bq_sz)
-        bkv = bkv_lst[i]
-        flash_attention(
-            bq,
-            bkv,
-            bq_idx=bq_idx,
-            bkv_idx=bkv_idx,
-            kv_head_idx=kv_head_idx,
-        )
+        cur_kv_head_bq = load_bq(bq_sem_idx,
+                                 cur_kv_head_idx,
+                                 actual_bq_sz=actual_bq_sz)
+        cur_kv_head__bkv = bkv_lst[i]
+        # FlashAttention is divided into `flash_attention_step1_qk_softmax`
+        # and `flash_attention_step2_pv` to pipeline the computation.
+        # `step2_pv` for the previous KV head, which depends on the softmax
+        # output, is overlapped with `step1_qk_softmax` for the current KV
+        # head, reducing overall wait times.
+        cur_kv_head_p, cur_kv_head_exp_m_diff = (
+            flash_attention_step1_qk_softmax(
+                cur_kv_head_bq,
+                cur_kv_head__bkv,
+                bq_idx=bq_idx,
+                bkv_idx=bkv_idx,
+                kv_head_idx=cur_kv_head_idx,
+            ))
+        if prev_bq_shape_0 is not None:
+          flash_attention_step2_pv(
+              prev_bq_shape_0,
+              prev_kv_head_bkv,
+              prev_kv_head_p,
+              prev_kv_head_exp_m_diff,
+              bkv_idx=bkv_idx,
+              kv_head_idx=prev_kv_head_idx,
+          )
+        prev_bq_shape_0 = cur_kv_head_bq.shape[0]
+        prev_kv_head_bkv = cur_kv_head__bkv
+        prev_kv_head_p = cur_kv_head_p
+        prev_kv_head_exp_m_diff = cur_kv_head_exp_m_diff
+        prev_kv_head_idx = cur_kv_head_idx
+
+      # Execute pv of last attention head.
+      assert prev_bq_shape_0 is not None
+      flash_attention_step2_pv(
+          prev_bq_shape_0,
+          prev_kv_head_bkv,
+          prev_kv_head_p,
+          prev_kv_head_exp_m_diff,
+          bkv_idx=bkv_idx,
+          kv_head_idx=prev_kv_head_idx,
+      )
 
       lax.fori_loop(bkv_idx_start,
                     num_bkv,
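
The restructured loop above is a one-stage software pipeline: step1 for the current KV head is issued before step2 for the previous one, and the final head's step2 is drained after the loop. The generic shape of that pattern, as a plain-Python sketch with illustrative names:

    def pipelined(items, step1, step2):
        prev = None
        for item in items:
            cur = (item, step1(item))   # start the current item's first stage
            if prev is not None:
                step2(*prev)            # finish the previous item's second stage
            prev = cur
        if prev is not None:
            step2(*prev)                # drain: finish the last item
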
@@ -1226,6 +1311,7 @@ def static_validate_inputs(
     static_argnames=(
         "sm_scale",
         "sliding_window",
+        "strict_sliding_window",
         "soft_cap",
         "mask_value",
         "q_scale",
@@ -1255,6 +1341,7 @@ def ragged_paged_attention_hd64(
     *,
     sm_scale: float = 1.0,
     sliding_window: int | None = None,
+    strict_sliding_window: bool = True,
     soft_cap: float | None = None,
     mask_value: float | None = DEFAULT_MASK_VALUE,
     q_scale: float | None = None,
@@ -1269,42 +1356,41 @@ def ragged_paged_attention_hd64(
     # Debug params.
     debug_mode: bool = False,
 ):
-  """A special Ragged paged attention version for head_dim=64 that supports mixed
-
-  prefill and decode.
-
-  Args:
-    queries: concatenated all sequences' queries.
-    keys: concatenated all sequences' keys (quantized).
-    values: concatenated all sequences' values (quantized).
-    kv_cache: paged KV cache with TPU-friendly shape.
-    kv_lens: padded kv lengths. Only the first num_seqs values are valid.
-    page_indices: flattened page indices look-up table by (seq_id, page_id).
-    cu_q_lens: the cumulative sum of the effective query lengths. Similar to
-      kv_lens, only the first num_seqs+1 values are valid.
-    distribution: (i, j, k) represents that sequences[0:i] are decode-only,
-      sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
-      k is also the total number of sequences.
-    attention_sink: optional attention sink for each q head.
-    actual_head_dim: the actual head size of the attention. Here we assume k and
-      v have the same actual head size.
-    sm_scale: the softmax scale which will be applied to the Q@K^T.
-    sliding_window: the sliding window size for the attention.
-    soft_cap: the logit soft cap for the attention.
-    mask_value: mask value for causal mask.
-    k_scale: the scale for the key cache.
-    v_scale: the scale for the value cache.
-    num_kv_pages_per_block: number of kv pages to be processed in one flash
-      attention block in the pallas kernel.
-    num_queries_per_block: number of kv pages to be processed in one flash
-      attention block in the pallas kernel.
-    vmem_limit_bytes: the vmem limit for the pallas kernel.
-    debug_mode: if true, RPA does not issue any DMAs or run flash attention but
-      print debug info. Need to compile with `--xla_tpu_enable_log_recorder`.
-
-  Returns:
-    The output of the attention.
-  """
+  """A variant of ragged paged attention for head_dim=64.
+
+  Args:
+    queries: concatenated all sequences' queries.
+    keys: concatenated all sequences' keys (quantized).
+    values: concatenated all sequences' values (quantized).
+    kv_cache: paged KV cache with TPU-friendly shape.
+    kv_lens: padded kv lengths. Only the first num_seqs values are valid.
+    page_indices: flattened page indices look-up table by (seq_id, page_id).
+    cu_q_lens: the cumulative sum of the effective query lengths. Similar to
+      kv_lens, only the first num_seqs+1 values are valid.
+    distribution: (i, j, k) represents that sequences[0:i] are decode-only,
+      sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
+      k is also the total number of sequences.
+    attention_sink: optional attention sink for each q head.
+    sm_scale: the softmax scale which will be applied to the Q@K^T.
+    sliding_window: the sliding window size for the attention.
+    strict_sliding_window: compute tokens that are strictly within the window.
+    soft_cap: the logit soft cap for the attention.
+    mask_value: mask value for causal mask.
+    q_scale: the scale for the query.
+    k_scale: the scale for the key cache.
+    v_scale: the scale for the value cache.
+    chunk_prefill_size: the chunk prefill size for the attention.
+    num_kv_pages_per_block: number of kv pages to be processed in one flash
+      attention block in the pallas kernel.
+    num_queries_per_block: number of kv pages to be processed in one flash
+      attention block in the pallas kernel.
+    vmem_limit_bytes: the vmem limit for the pallas kernel.
+    debug_mode: if true, RPA does not issue any DMAs or run flash attention but
+      print debug info. Need to compile with `--xla_tpu_enable_log_recorder`.
+
+  Returns:
+    The output of the attention.
+  """
   q, k, v = queries, keys, values
   static_validate_inputs(
       q,
@@ -1374,7 +1460,7 @@ def ragged_paged_attention_hd64(
       pl.BlockSpec(memory_space=pltpu.HBM),
       pl.BlockSpec(memory_space=pltpu.HBM),
       None if attention_sink is None else pl.BlockSpec(
-          memory_space=pltpu.VMEM)
+          memory_space=pltpu.VMEM),
   ]
 
   out_specs = [
@@ -1438,6 +1524,7 @@ def ragged_paged_attention_hd64(
       _ragged_paged_attention_kernel,
       sm_scale=sm_scale,
       sliding_window=sliding_window,
+      strict_sliding_window=strict_sliding_window,
       soft_cap=soft_cap,
       mask_value=mask_value,
       q_scale=q_scale,
@@ -308,7 +308,13 @@ def sharded_ragged_paged_attention(
     args = (q, k, v, kv_cache, kv_lens, page_indices, cu_q_lens, distribution)
 
     use_hd64 = q.shape[-1] == 64
-    func = ragged_paged_attention_hd64 if use_hd64 else ragged_paged_attention
+
+    func = ragged_paged_attention
+    if use_hd64:
+        func = functools.partial(ragged_paged_attention_hd64,
+                                 strict_sliding_window=False)
+    else:
+        func = ragged_paged_attention
 
     if attention_sink is not None:
         if not use_hd64:
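
On the hd64 path, the wrapper above now binds strict_sliding_window=False once via functools.partial so that both kernels are invoked with the same argument list afterwards. A minimal sketch of that dispatch pattern with hypothetical stand-in functions (not the real kernel signatures):

    import functools

    def rpa(*args, **kwargs):                                    # stand-in for ragged_paged_attention
        ...

    def rpa_hd64(*args, strict_sliding_window=True, **kwargs):   # stand-in for the hd64 variant
        ...

    def pick_kernel(head_dim):
        if head_dim == 64:
            # Bind the keyword once; callers then treat both kernels uniformly.
            return functools.partial(rpa_hd64, strict_sliding_window=False)
        return rpa
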