tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511130813__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (67)
  1. tests/kernels/fused_moe_v1_test.py +34 -303
  2. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
  3. tests/lora/test_layers.py +6 -0
  4. tests/lora/utils.py +8 -0
  5. tests/test_utils.py +16 -24
  6. tpu_inference/__init__.py +3 -22
  7. tpu_inference/core/core_tpu.py +9 -17
  8. tpu_inference/core/disagg_utils.py +8 -6
  9. tpu_inference/distributed/tpu_connector.py +4 -3
  10. tpu_inference/distributed/utils.py +2 -3
  11. tpu_inference/envs.py +8 -61
  12. tpu_inference/executors/ray_distributed_executor.py +11 -31
  13. tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
  15. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +143 -287
  16. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -7
  17. tpu_inference/layers/jax/attention/attention.py +1 -1
  18. tpu_inference/layers/{common → jax}/attention_interface.py +2 -8
  19. tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
  20. tpu_inference/layers/jax/sample/sampling.py +2 -2
  21. tpu_inference/layers/{common → jax}/sharding.py +5 -5
  22. tpu_inference/layers/vllm/attention.py +1 -1
  23. tpu_inference/layers/vllm/fused_moe.py +208 -170
  24. tpu_inference/layers/vllm/quantization/__init__.py +3 -7
  25. tpu_inference/layers/vllm/quantization/awq.py +3 -4
  26. tpu_inference/layers/vllm/quantization/common.py +1 -6
  27. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +2 -4
  28. tpu_inference/layers/vllm/quantization/unquantized.py +67 -62
  29. tpu_inference/layers/vllm/sharding.py +2 -2
  30. tpu_inference/lora/torch_punica_tpu.py +2 -1
  31. tpu_inference/mock/__init__.py +0 -0
  32. tpu_inference/mock/vllm_config_utils.py +28 -0
  33. tpu_inference/mock/vllm_envs.py +1219 -0
  34. tpu_inference/mock/vllm_logger.py +212 -0
  35. tpu_inference/mock/vllm_logging_utils.py +15 -0
  36. tpu_inference/models/common/model_loader.py +12 -46
  37. tpu_inference/models/jax/llama3.py +3 -4
  38. tpu_inference/models/jax/llama_eagle3.py +5 -8
  39. tpu_inference/models/jax/phi3.py +376 -0
  40. tpu_inference/models/jax/qwen2.py +2 -3
  41. tpu_inference/models/jax/qwen2_5_vl.py +50 -165
  42. tpu_inference/models/jax/qwen3.py +2 -3
  43. tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
  44. tpu_inference/models/jax/utils/weight_utils.py +143 -198
  45. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -32
  46. tpu_inference/platforms/tpu_platform.py +34 -47
  47. tpu_inference/runner/compilation_manager.py +60 -145
  48. tpu_inference/runner/kv_cache.py +2 -2
  49. tpu_inference/runner/kv_cache_manager.py +18 -17
  50. tpu_inference/runner/persistent_batch_manager.py +2 -40
  51. tpu_inference/runner/structured_decoding_manager.py +3 -2
  52. tpu_inference/runner/tpu_runner.py +135 -283
  53. tpu_inference/runner/utils.py +2 -2
  54. tpu_inference/spec_decode/jax/eagle3.py +21 -71
  55. tpu_inference/tpu_info.py +3 -4
  56. tpu_inference/utils.py +15 -38
  57. tpu_inference/worker/tpu_worker.py +26 -163
  58. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/METADATA +3 -4
  59. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/RECORD +63 -61
  60. tests/test_envs.py +0 -203
  61. tpu_inference/layers/common/quant_methods.py +0 -8
  62. tpu_inference/layers/vllm/quantization/mxfp4.py +0 -331
  63. tpu_inference/models/jax/llama_guard_4.py +0 -361
  64. /tpu_inference/layers/{common → jax}/binary_search.py +0 -0
  65. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/WHEEL +0 -0
  66. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/licenses/LICENSE +0 -0
  67. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/top_level.txt +0 -0
@@ -267,7 +267,6 @@ def _ragged_paged_attention_kernel(
  *,
  sm_scale: float,
  sliding_window: int | None = None,
- strict_sliding_window: bool = True,
  soft_cap: float | None = None,
  mask_value: float = DEFAULT_MASK_VALUE,
  q_scale: float | None = None,
@@ -318,21 +317,6 @@ def _ragged_paged_attention_kernel(
  q_len = q_end - q_start
  kv_len = kv_lens_ref[seq_idx]

- if sliding_window is None:
- bkv_idx_start = next_seq_bkv_idx_start = 0
- else:
- bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window,
- 0) // bkv_sz
-
- def get_next_bkv_idx_start():
- next_kv_len = kv_lens_ref[seq_idx + 1]
- next_q_len = cu_q_lens_ref[seq_idx + 2] - q_end
- return jnp.maximum(next_kv_len - next_q_len - sliding_window,
- 0) // bkv_sz
-
- next_seq_bkv_idx_start = lax.cond(seq_idx + 1 < num_seqs,
- get_next_bkv_idx_start, lambda: 0)
-
  def debug_print(msg, *args):
  if debug_mode:
  pl.debug_print(msg, *args)
@@ -352,7 +336,7 @@ def _ragged_paged_attention_kernel(
  debug_print("[RPA debug] q_len={}", q_len)
  debug_print("[RPA debug] kv_len={}", kv_len)

- def flash_attention_step1_qk_softmax(
+ def flash_attention(
  q, # [actual_bq_sz * num_q_heads_per_kv_head, actual_head_dim_x2]
  kv, # [bkv_sz, actual_head_dim_x2]
  *,
@@ -366,10 +350,11 @@ def _ragged_paged_attention_kernel(
  assert kv.shape == (bkv_sz, actual_head_dim_x2)
  head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
  head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
+ head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]

  def load_with_init(ref, init_val):
- return jnp.where(bkv_idx == bkv_idx_start,
- jnp.full_like(ref, init_val), ref[...])
+ return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val),
+ ref[...])

  # Follow FlashAttention-2 forward pass.
  if q_scale is not None:
@@ -387,27 +372,26 @@ def _ragged_paged_attention_kernel(
  s *= k_scale
  if q_scale is not None:
  s *= q_scale
- if soft_cap is not None:
- s = soft_cap * jnp.tanh(s / soft_cap)

  q_span = (kv_len - q_len + bq_idx * bq_sz +
  lax.broadcasted_iota(jnp.int32, s.shape, 0) //
  num_q_heads_per_kv_head)
  k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
- mask = k_span <= q_span
-
- if sliding_window is not None and strict_sliding_window:
- mask = jnp.logical_and(mask, q_span - sliding_window < k_span)
+ mask = q_span < k_span
+ # TODO(jevinjiang, xiowei): reduce pages_per_seq based on sliding_window.
+ if sliding_window is not None:
+ mask = jnp.logical_or(mask, q_span - sliding_window >= k_span)

- s = jnp.where(mask, s, mask_value)
+ if soft_cap is not None:
+ s = soft_cap * jnp.tanh(s / soft_cap)
+ s += jnp.where(mask, mask_value, 0.0)
  s_rowmax = jnp.max(s, axis=1, keepdims=True)

  if attention_sink_ref is not None:
  sinks = attention_sink_ref[kv_head_idx]
  actual_bq_sz = q.shape[0] // num_q_heads_per_kv_head
  m_prev_init = jnp.concat([sinks] * actual_bq_sz, axis=0)
- m_prev = jnp.where(bkv_idx == bkv_idx_start, m_prev_init,
- head_m_ref[...])
+ m_prev = jnp.where(bkv_idx == 0, m_prev_init, head_m_ref[...])
  else:
  m_prev = load_with_init(head_m_ref, -jnp.inf)
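Illustrative sketch (not part of the diff): the rewritten masking path in the hunk above builds an "invalid positions" mask and adds mask_value to the logits after the optional soft cap, instead of selecting valid logits before it. A minimal standalone JAX version of that logic, with block indices and sizes passed as plain arguments and a stand-in mask value, could look roughly like this:

import jax.numpy as jnp
from jax import lax

def masked_logits(s, kv_len, q_len, bq_idx, bq_sz, bkv_idx, bkv_sz,
                  num_q_heads_per_kv_head, sliding_window=None,
                  soft_cap=None, mask_value=-1e30):
    # mask_value is a stand-in for the kernel's DEFAULT_MASK_VALUE.
    # Absolute token position covered by each query row and key column of `s`.
    q_span = (kv_len - q_len + bq_idx * bq_sz +
              lax.broadcasted_iota(jnp.int32, s.shape, 0) //
              num_q_heads_per_kv_head)
    k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
    # Invalid positions: keys in the query's future...
    mask = q_span < k_span
    # ...or, with a sliding window, keys that have fallen out of the window.
    if sliding_window is not None:
        mask = jnp.logical_or(mask, q_span - sliding_window >= k_span)
    if soft_cap is not None:
        s = soft_cap * jnp.tanh(s / soft_cap)
    # Additive masking: valid logits keep their (soft-capped) values,
    # masked ones are pushed to roughly mask_value.
    return s + jnp.where(mask, mask_value, 0.0)

With a sufficiently negative mask value the masked entries vanish after the softmax, so this should behave like the previous select-based masking while applying the soft cap only once to the raw logits.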

@@ -415,33 +399,15 @@ def _ragged_paged_attention_kernel(
  head_m_ref[...] = m_curr
  p = jnp.exp(s - broadcast_minor(m_curr, s.shape))

+ pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
+ if v_scale is not None:
+ pv *= v_scale
+
  p_rowsum = jnp.sum(p, axis=1, keepdims=True)
  exp_m_diff = jnp.exp(m_prev - m_curr)
  l_prev = load_with_init(head_l_ref, 1.0)
  l_curr = exp_m_diff * l_prev + p_rowsum
  head_l_ref[...] = l_curr
-
- return p, exp_m_diff
-
- def flash_attention_step2_pv(
- q_shape_0,
- kv, # [bkv_sz, actual_head_dim_x2]
- p, # from step1
- exp_m_diff, # from step1
- *,
- bkv_idx,
- kv_head_idx,
- ):
- head_acc_ref = acc_ref.at[kv_head_idx, :q_shape_0]
-
- def load_with_init(ref, init_val):
- return jnp.where(bkv_idx == bkv_idx_start,
- jnp.full_like(ref, init_val), ref[...])
-
- pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
- if v_scale is not None:
- pv *= v_scale
-
  o_prev = load_with_init(head_acc_ref, 0.0)
  o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
  head_acc_ref[...] = o_curr
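For context on what the merged flash_attention computes per (query block, kv block) pair: it is the FlashAttention-2 online-softmax update, with the running max, running denominator, and output accumulator carried in VMEM refs. The following is a minimal, self-contained sketch of that recurrence using plain arrays; it omits the grouped-query layout, masking, scaling, and the packed K/V buffer used by the kernel.

import jax.numpy as jnp

def online_attention(q, k_blocks, v_blocks):
    # Running row-wise max, softmax denominator, and unnormalized output.
    m = jnp.full((q.shape[0], 1), -jnp.inf)
    l = jnp.zeros((q.shape[0], 1))
    acc = jnp.zeros((q.shape[0], v_blocks[0].shape[1]))
    for k, v in zip(k_blocks, v_blocks):
        s = q @ k.T                                     # logits for this kv block
        m_curr = jnp.maximum(jnp.max(s, axis=1, keepdims=True), m)
        p = jnp.exp(s - m_curr)                         # rescaled probabilities
        exp_m_diff = jnp.exp(m - m_curr)                # correction for previous state
        l = exp_m_diff * l + jnp.sum(p, axis=1, keepdims=True)
        acc = exp_m_diff * acc + p @ v
        m = m_curr
    return acc / l                                      # normalize once at the end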
@@ -456,12 +422,7 @@ def _ragged_paged_attention_kernel(
  else:
  cp.start()

- def _fetch_bkv(seq_idx,
- bkv_idx,
- bkv_sem_idx,
- *,
- is_full_fetch=False,
- wait=False):
+ def _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, wait=False):
  sem = sems.at[0, bkv_sem_idx]
  vmem_ref = bkv_x2_ref.at[bkv_sem_idx]

@@ -502,73 +463,42 @@ def _ragged_paged_attention_kernel(
  debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
  debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

- if not wait:
- # Fetch effective kv from kv cache.
- def loop_body(i, offset):
- sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
- _async_copy(
- cache_hbm_ref.at[pl.ds(
- page_indices_ref[page_indices_offset + i] * page_size,
- sz)],
- vmem_ref.at[pl.ds(i * page_size, sz)],
- sem,
- wait=False,
- )
- debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
- return offset + sz
-
- offset = lax.fori_loop(
- 0,
- bkv_p_frm_cache,
- loop_body,
- 0, # offset
- unroll=False,
+ # Fetch effective kv from kv cache.
+ def loop_body(i, offset):
+ sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
+ _async_copy(
+ cache_hbm_ref.at[pl.ds(
+ page_indices_ref[page_indices_offset + i] * page_size,
+ sz)],
+ vmem_ref.at[pl.ds(i * page_size, sz)],
+ sem,
+ wait,
  )
+ debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+ return offset + sz
+
+ offset = lax.fori_loop(
+ 0,
+ bkv_p_frm_cache,
+ loop_body,
+ 0, # offset
+ unroll=False,
+ )

- # Fetch kv directly from new kv.
- @pl.when(bkv_sz_frm_new > 0)
- def _fetch_bkv_from_new_kv():
- new_kv_len_start = q_end - kv_left_frm_new
- debug_print("[RPA debug] new_kv_len_start={}",
- new_kv_len_start)
- debug_print("[RPA debug] offset_in_bkv={}", offset)
- _async_copy(
- kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
- vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
- sem,
- wait,
- )
-
- # NOTE(chengjiyao): This condition is true for the first two bkv fetches.
- # We need to ensure the bkv_x2_ref VMEM buffer is fully initialized to
- # avoid potential NaN values in regions not overwritten by actual data.
- # This is done by padding the remaining parts of the buffer with data
- # from the KV cache. This special handling is only strictly necessary
- # until both buffers in the double buffer (bkv_x2_ref) have been written
- # to at least once.
- @pl.when(is_full_fetch)
- def _make_sure_bkv_vmem_is_not_nan():
- effective_sz = offset + bkv_sz_frm_new
- remaining_sz = bkv_sz - effective_sz
- _async_copy(
- cache_hbm_ref.at[pl.ds(0, remaining_sz)],
- vmem_ref.at[pl.ds(effective_sz, remaining_sz)],
- sem,
- wait,
- )
-
- return kv_len_start + offset, bkv_sz_frm_new
- else:
- offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
- sz = lax.select(is_full_fetch, bkv_sz, offset + bkv_sz_frm_new)
- dst = vmem_ref.at[pl.ds(0, sz)]
+ # Fetch kv directly from new kv.
+ @pl.when(bkv_sz_frm_new > 0)
+ def _fetch_bkv_from_new_kv():
+ new_kv_len_start = q_end - kv_left_frm_new
+ debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
+ debug_print("[RPA debug] offset_in_bkv={}", offset)
  _async_copy(
- src=dst,
- dst=dst,
- sem=sem,
- wait=True,
+ kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
+ vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
+ sem,
+ wait,
  )
- return kv_len_start + offset, bkv_sz_frm_new
+
+ return kv_len_start + offset, bkv_sz_frm_new

  def _update_kv_cache(seq_idx,
  bkv_sem_idx,
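Host-side sketch (plain arrays, no DMAs or semaphores) of what the simplified _fetch_bkv in the hunk above assembles: up to bkv_p_frm_cache pages gathered from the paged cache through page_indices, followed by the tail of newly computed KV for this step. Names mirror the kernel, but the array layout and argument list are simplified for illustration.

import numpy as np

def gather_bkv(cache, new_kv, page_indices, page_indices_offset,
               kv_left_frm_cache, bkv_p_frm_cache, bkv_sz_frm_new,
               new_kv_len_start, page_size):
    chunks = []
    # Copy whole pages (the last one possibly partial) out of the paged cache.
    for i in range(bkv_p_frm_cache):
        sz = min(page_size, kv_left_frm_cache - i * page_size)
        page = page_indices[page_indices_offset + i]
        chunks.append(cache[page * page_size:page * page_size + sz])
    # Append freshly computed KV that has not been written to the cache yet.
    if bkv_sz_frm_new > 0:
        chunks.append(new_kv[new_kv_len_start:new_kv_len_start + bkv_sz_frm_new])
    return np.concatenate(chunks) if chunks else new_kv[:0]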
@@ -604,41 +534,30 @@ def _ragged_paged_attention_kernel(
  debug_print("[RPA debug] p_ignore={}", p_ignore)
  debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

- if not wait:
-
- def loop_body(i, states):
- update_sz, ignore = states
- sz = jnp.minimum(page_size - ignore, update_sz)
-
- _async_copy(
- vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
- sz)],
- cache_hbm_ref.at[pl.ds(
- page_indices_ref[page_indices_offset + i] * page_size +
- ignore,
- sz,
- )],
- sem,
- wait=False,
- )
- debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
- return update_sz - sz, 0
-
- lax.fori_loop(
- 0,
- kv_p_end - kv_p_start,
- loop_body,
- (update_sz, ignore), # total transfer size
- unroll=False,
- )
- else:
- dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
+ def loop_body(i, states):
+ update_sz, ignore = states
+ sz = jnp.minimum(page_size - ignore, update_sz)
+
  _async_copy(
- src=dst,
- dst=dst,
- sem=sem,
- wait=True,
+ vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore, sz)],
+ cache_hbm_ref.at[pl.ds(
+ page_indices_ref[page_indices_offset + i] * page_size +
+ ignore,
+ sz,
+ )],
+ sem,
+ wait,
  )
+ debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+ return update_sz - sz, 0
+
+ lax.fori_loop(
+ 0,
+ kv_p_end - kv_p_start,
+ loop_body,
+ (update_sz, ignore), # total transfer size
+ unroll=False,
+ )

  def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
  sem = sems.at[1, bq_sem_idx]
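Similarly, a host-side sketch of the _update_kv_cache write-back loop in the hunk above: the block held in VMEM is scattered back into the paged cache, where only the first page may need to skip `ignore` tokens that are already present. Plain arrays stand in for the DMA copies; names mirror the kernel but the layout is simplified.

import numpy as np

def update_kv_cache(cache, vmem, page_indices, page_indices_offset,
                    kv_p_start, kv_p_end, p_ignore, ignore, update_sz,
                    page_size):
    for i in range(kv_p_end - kv_p_start):
        sz = min(page_size - ignore, update_sz)
        src = (p_ignore + i) * page_size + ignore
        dst = page_indices[page_indices_offset + i] * page_size + ignore
        cache[dst:dst + sz] = vmem[src:src + sz]
        update_sz -= sz
        ignore = 0  # only the first page starts at a partial offset
    return cache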
@@ -688,18 +607,11 @@ def _ragged_paged_attention_kernel(
  wait,
  )

- def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, is_full_fetch=False):
- return _fetch_bkv(seq_idx,
- bkv_idx,
- bkv_sem_idx,
- is_full_fetch=is_full_fetch)
+ def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
+ return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)

- def wait_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, is_full_fetch=False):
- return _fetch_bkv(seq_idx,
- bkv_idx,
- bkv_sem_idx,
- is_full_fetch=is_full_fetch,
- wait=True)
+ def wait_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
+ return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, wait=True)

  def start_fetch_bq(seq_idx, bq_idx, bq_sem_idx):
  return _fetch_bq(seq_idx, bq_idx, bq_sem_idx)
@@ -757,7 +669,7 @@ def _ragged_paged_attention_kernel(
  vec = ref[start::step]
  return vec

- def strided_load_bkv(bkv_sem_idx, start, step):
+ def strided_load_bkv(bkv_sem_idx, start, step, *, bkv_mask):
  assert start % kv_packing == 0
  assert step % kv_packing == 0
  start //= kv_packing
@@ -766,6 +678,7 @@ def _ragged_paged_attention_kernel(
  bkv_sz * step, actual_head_dim_x2))

  kv = strided_load(kv_ref, start, step)
+ kv = lax.select(bkv_mask, kv, jnp.zeros_like(kv))
  bitwidth = 32 // kv_packing
  repack_ty = jnp.dtype(f"uint{bitwidth}")
  lst = []
@@ -806,23 +719,12 @@ def _ragged_paged_attention_kernel(
  def get_next_bkv_ids(seq_idx, bq_idx, bkv_idx, bkv_sem_idx):
  next_bkv_idx = bkv_idx + 1
  is_last_bkv = next_bkv_idx == num_bkv
+ next_bkv_idx = lax.select(is_last_bkv, 0, next_bkv_idx)
  next_bq_idx = lax.select(is_last_bkv, bq_idx + 1, bq_idx)
  is_last_bq = next_bq_idx == num_bq
  next_bq_idx = lax.select(is_last_bq, 0, next_bq_idx)
  next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
  next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)
-
- if sliding_window is None:
- next_bkv_start_idx = 0
- else:
- next_bkv_start_idx = lax.select(
- is_last_bq,
- next_seq_bkv_idx_start,
- bkv_idx_start,
- )
- next_bkv_idx = lax.select(is_last_bkv, next_bkv_start_idx,
- next_bkv_idx)
-
  return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx

  def compute_with_bq(bq_idx, _):
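With the sliding-window start index removed, get_next_bkv_ids above reduces to a three-level counter rollover (kv block, then query block, then sequence) plus a toggle of the double-buffer semaphore. A plain-Python equivalent of the new logic, with num_bq and num_bkv as assumed loop bounds for illustration:

def next_bkv_ids(seq_idx, bq_idx, bkv_idx, bkv_sem_idx, *, num_bq, num_bkv):
    next_bkv_idx = bkv_idx + 1
    is_last_bkv = next_bkv_idx == num_bkv
    next_bkv_idx = 0 if is_last_bkv else next_bkv_idx          # kv block restarts at 0
    next_bq_idx = bq_idx + 1 if is_last_bkv else bq_idx        # advance bq on last kv block
    is_last_bq = next_bq_idx == num_bq
    next_bq_idx = 0 if is_last_bq else next_bq_idx
    next_seq_idx = seq_idx + 1 if is_last_bq else seq_idx      # advance sequence on last bq
    next_bkv_sem_idx = 1 - bkv_sem_idx                         # toggle double-buffer semaphore
    return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx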
@@ -839,36 +741,31 @@ def _ragged_paged_attention_kernel(
  def compute_with_bkv(bkv_idx, _):
  # Create bitmask for KV.
  assert bkv_sz % kv_packing == 0
+ actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
+ bkv_shape = (bkv_sz, actual_head_dim_x2)
+ bkv_mask = lax.broadcasted_iota(jnp.int32, bkv_shape,
+ 0) < actual_bkv_sz

  # Get next bkv ids.
  bkv_sem_idx = sem_ids_ref[1]
- next_seq_idx, next_bq_idx_for_kv, next_bkv_idx, next_bkv_sem_idx = (
- get_next_bkv_ids(seq_idx, bq_idx, bkv_idx, bkv_sem_idx))
+ next_seq_idx, _, next_bkv_idx, next_bkv_sem_idx = get_next_bkv_ids(
+ seq_idx, bq_idx, bkv_idx, bkv_sem_idx)

  # Prefetch next bkv
  @pl.when(next_seq_idx < num_seqs)
  def prefetch_next_bkv():
  sem_ids_ref[1] = next_bkv_sem_idx
- start_fetch_bkv(
- next_seq_idx,
- next_bkv_idx,
- next_bkv_sem_idx,
- is_full_fetch=next_seq_idx + next_bq_idx_for_kv +
- next_bkv_idx < 2,
- )
+ start_fetch_bkv(next_seq_idx, next_bkv_idx,
+ next_bkv_sem_idx)

  # Wait for cur bq if not ready yet
- @pl.when(bkv_idx == bkv_idx_start)
+ @pl.when(bkv_idx == 0)
  def wait_cur_bq():
  wait_fetch_bq(seq_idx, bq_idx, bq_sem_idx)

  # Wait for cur bkv
- offset, update_sz = wait_fetch_bkv(
- seq_idx,
- bkv_idx,
- bkv_sem_idx,
- is_full_fetch=seq_idx + bq_idx + bkv_idx < 2,
- )
+ offset, update_sz = wait_fetch_bkv(seq_idx, bkv_idx,
+ bkv_sem_idx)

  # Start updating bkv to kv cache if applicable.
  # Only needed in first bq loop.
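The bkv_mask introduced above replaces the removed is_full_fetch padding path: rather than backfilling the double-buffered VMEM block with cache data to avoid NaNs, rows of the kv block beyond the valid kv length are zeroed when the block is loaded (the lax.select added in strided_load_bkv). A standalone sketch of that masking, assuming an unpacked [bkv_sz, head_dim_x2] block:

import jax.numpy as jnp
from jax import lax

def mask_kv_block(kv_block, kv_len, bkv_idx):
    bkv_sz = kv_block.shape[0]
    # Rows past the end of this sequence's valid KV are zeroed out.
    actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
    bkv_mask = lax.broadcasted_iota(jnp.int32, kv_block.shape, 0) < actual_bkv_sz
    return lax.select(bkv_mask, kv_block, jnp.zeros_like(kv_block))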
@@ -887,70 +784,31 @@ def _ragged_paged_attention_kernel(
  return

  # Flash attention with cur bkv and bq
- prev_bq_shape_0 = None
- prev_kv_head_bkv = None
- prev_kv_head_idx = None
- prev_kv_head_p = None
- prev_kv_head_exp_m_diff = None
  for kv_head_start in range(0, actual_num_kv_heads, kv_packing):
  bkv_lst = strided_load_bkv(
  bkv_sem_idx,
  kv_head_start,
  num_kv_heads,
+ bkv_mask=bkv_mask,
  )
  assert len(bkv_lst) == kv_packing
  for i in range(kv_packing):
- cur_kv_head_idx = kv_head_start + i
- if cur_kv_head_idx >= actual_num_kv_heads:
+ kv_head_idx = kv_head_start + i
+ if kv_head_idx >= actual_num_kv_heads:
  break
- cur_kv_head_bq = load_bq(bq_sem_idx,
- cur_kv_head_idx,
- actual_bq_sz=actual_bq_sz)
- cur_kv_head__bkv = bkv_lst[i]
- # FlashAttention is divided into `flash_attention_step1_qk_softmax`
- # and `flash_attention_step2_pv` to pipeline the computation.
- # `step2_pv` for the previous KV head, which depends on the softmax
- # output, is overlapped with `step1_qk_softmax` for the current KV
- # head, reducing overall wait times.
- cur_kv_head_p, cur_kv_head_exp_m_diff = (
- flash_attention_step1_qk_softmax(
- cur_kv_head_bq,
- cur_kv_head__bkv,
- bq_idx=bq_idx,
- bkv_idx=bkv_idx,
- kv_head_idx=cur_kv_head_idx,
- ))
- if prev_bq_shape_0 is not None:
- flash_attention_step2_pv(
- prev_bq_shape_0,
- prev_kv_head_bkv,
- prev_kv_head_p,
- prev_kv_head_exp_m_diff,
- bkv_idx=bkv_idx,
- kv_head_idx=prev_kv_head_idx,
- )
- prev_bq_shape_0 = cur_kv_head_bq.shape[0]
- prev_kv_head_bkv = cur_kv_head__bkv
- prev_kv_head_p = cur_kv_head_p
- prev_kv_head_exp_m_diff = cur_kv_head_exp_m_diff
- prev_kv_head_idx = cur_kv_head_idx
-
- # Execute pv of last attention head.
- assert prev_bq_shape_0 is not None
- flash_attention_step2_pv(
- prev_bq_shape_0,
- prev_kv_head_bkv,
- prev_kv_head_p,
- prev_kv_head_exp_m_diff,
- bkv_idx=bkv_idx,
- kv_head_idx=prev_kv_head_idx,
- )
-
- lax.fori_loop(bkv_idx_start,
- num_bkv,
- compute_with_bkv,
- None,
- unroll=False)
+ bq = load_bq(bq_sem_idx,
+ kv_head_idx,
+ actual_bq_sz=actual_bq_sz)
+ bkv = bkv_lst[i]
+ flash_attention(
+ bq,
+ bkv,
+ bq_idx=bq_idx,
+ bkv_idx=bkv_idx,
+ kv_head_idx=kv_head_idx,
+ )
+
+ lax.fori_loop(0, num_bkv, compute_with_bkv, None, unroll=False)

  # Load acc and calculate final output.
  acc = acc_ref[...]
@@ -980,7 +838,7 @@ def _ragged_paged_attention_kernel(
  @pl.when(seq_idx == 0)
  def prologue():
  start_fetch_bq(0, 0, 0)
- start_fetch_bkv(0, bkv_idx_start, 0, is_full_fetch=True)
+ start_fetch_bkv(0, 0, 0)

  @pl.when(seq_idx < decode_end)
  def process_decode():
@@ -1345,7 +1203,6 @@ def static_validate_inputs(
  static_argnames=(
  "sm_scale",
  "sliding_window",
- "strict_sliding_window",
  "soft_cap",
  "mask_value",
  "q_scale",
@@ -1375,7 +1232,6 @@ def ragged_paged_attention_hd64(
  *,
  sm_scale: float = 1.0,
  sliding_window: int | None = None,
- strict_sliding_window: bool = True,
  soft_cap: float | None = None,
  mask_value: float | None = DEFAULT_MASK_VALUE,
  q_scale: float | None = None,
@@ -1390,41 +1246,42 @@ def ragged_paged_attention_hd64(
  # Debug params.
  debug_mode: bool = False,
  ):
- """A variant of ragged paged attention for head_dim=64.
-
- Args:
- queries: concatenated all sequences' queries.
- keys: concatenated all sequences' keys (quantized).
- values: concatenated all sequences' values (quantized).
- kv_cache: paged KV cache with TPU-friendly shape.
- kv_lens: padded kv lengths. Only the first num_seqs values are valid.
- page_indices: flattened page indices look-up table by (seq_id, page_id).
- cu_q_lens: the cumulative sum of the effective query lengths. Similar to
- kv_lens, only the first num_seqs+1 values are valid.
- distribution: (i, j, k) represents that sequences[0:i] are decode-only,
- sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
- k is also the total number of sequences.
- attention_sink: optional attention sink for each q head.
- sm_scale: the softmax scale which will be applied to the Q@K^T.
- sliding_window: the sliding window size for the attention.
- strict_sliding_window: compute tokens that are strictly within the window.
- soft_cap: the logit soft cap for the attention.
- mask_value: mask value for causal mask.
- q_scale: the scale for the query.
- k_scale: the scale for the key cache.
- v_scale: the scale for the value cache.
- chunk_prefill_size: the chunk prefill size for the attention.
- num_kv_pages_per_block: number of kv pages to be processed in one flash
- attention block in the pallas kernel.
- num_queries_per_block: number of kv pages to be processed in one flash
- attention block in the pallas kernel.
- vmem_limit_bytes: the vmem limit for the pallas kernel.
- debug_mode: if true, RPA does not issue any DMAs or run flash attention but
- print debug info. Need to compile with `--xla_tpu_enable_log_recorder`.
-
- Returns:
- The output of the attention.
- """
+ """A special Ragged paged attention version for head_dim=64 that supports mixed
+
+ prefill and decode.
+
+ Args:
+ queries: concatenated all sequences' queries.
+ keys: concatenated all sequences' keys (quantized).
+ values: concatenated all sequences' values (quantized).
+ kv_cache: paged KV cache with TPU-friendly shape.
+ kv_lens: padded kv lengths. Only the first num_seqs values are valid.
+ page_indices: flattened page indices look-up table by (seq_id, page_id).
+ cu_q_lens: the cumulative sum of the effective query lengths. Similar to
+ kv_lens, only the first num_seqs+1 values are valid.
+ distribution: (i, j, k) represents that sequences[0:i] are decode-only,
+ sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
+ k is also the total number of sequences.
+ attention_sink: optional attention sink for each q head.
+ actual_head_dim: the actual head size of the attention. Here we assume k and
+ v have the same actual head size.
+ sm_scale: the softmax scale which will be applied to the Q@K^T.
+ sliding_window: the sliding window size for the attention.
+ soft_cap: the logit soft cap for the attention.
+ mask_value: mask value for causal mask.
+ k_scale: the scale for the key cache.
+ v_scale: the scale for the value cache.
+ num_kv_pages_per_block: number of kv pages to be processed in one flash
+ attention block in the pallas kernel.
+ num_queries_per_block: number of kv pages to be processed in one flash
+ attention block in the pallas kernel.
+ vmem_limit_bytes: the vmem limit for the pallas kernel.
+ debug_mode: if true, RPA does not issue any DMAs or run flash attention but
+ print debug info. Need to compile with `--xla_tpu_enable_log_recorder`.
+
+ Returns:
+ The output of the attention.
+ """
  q, k, v = queries, keys, values
  static_validate_inputs(
  q,
@@ -1494,7 +1351,7 @@ def ragged_paged_attention_hd64(
  pl.BlockSpec(memory_space=pltpu.HBM),
  pl.BlockSpec(memory_space=pltpu.HBM),
  None if attention_sink is None else pl.BlockSpec(
- memory_space=pltpu.VMEM),
+ memory_space=pltpu.VMEM)
  ]

  out_specs = [
@@ -1558,7 +1415,6 @@ def ragged_paged_attention_hd64(
  _ragged_paged_attention_kernel,
  sm_scale=sm_scale,
  sliding_window=sliding_window,
- strict_sliding_window=strict_sliding_window,
  soft_cap=soft_cap,
  mask_value=mask_value,
  q_scale=q_scale,
@@ -1231,13 +1231,6 @@ TUNED_BLOCK_SIZES = {
  },
  }
  },
- 16: {
- 'q_bfloat16_kv_bfloat16': {
- 'q_head-8_kv_head-1_head-128': {
- 262144: (128, 256),
- }
- }
- },
  },
  'TPU v5e': {
  128: {
@@ -13,9 +13,9 @@ from tpu_inference import utils
  from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
  ragged_paged_attention
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
- from tpu_inference.layers.common.sharding import ShardingAxisName
  from tpu_inference.layers.jax.base import create_param
  from tpu_inference.layers.jax.rope_interface import apply_rope
+ from tpu_inference.layers.jax.sharding import ShardingAxisName

  KVCache = Tuple[jax.Array, jax.Array]