tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511180814__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (56)
  1. tests/kernels/fused_moe_v1_test.py +34 -303
  2. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
  3. tests/lora/test_layers.py +6 -0
  4. tests/lora/utils.py +8 -0
  5. tests/test_envs.py +11 -32
  6. tests/test_utils.py +2 -1
  7. tpu_inference/__init__.py +3 -22
  8. tpu_inference/core/disagg_utils.py +8 -6
  9. tpu_inference/distributed/tpu_connector.py +4 -3
  10. tpu_inference/distributed/utils.py +2 -3
  11. tpu_inference/envs.py +8 -61
  12. tpu_inference/executors/ray_distributed_executor.py +2 -9
  13. tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
  15. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +145 -266
  16. tpu_inference/layers/common/attention_interface.py +1 -7
  17. tpu_inference/layers/common/sharding.py +5 -5
  18. tpu_inference/layers/vllm/fused_moe.py +208 -170
  19. tpu_inference/layers/vllm/quantization/common.py +1 -6
  20. tpu_inference/layers/vllm/quantization/mxfp4.py +73 -138
  21. tpu_inference/layers/vllm/quantization/unquantized.py +64 -58
  22. tpu_inference/layers/vllm/sharding.py +2 -2
  23. tpu_inference/lora/torch_punica_tpu.py +2 -1
  24. tpu_inference/mock/__init__.py +0 -0
  25. tpu_inference/mock/vllm_config_utils.py +28 -0
  26. tpu_inference/mock/vllm_envs.py +1219 -0
  27. tpu_inference/mock/vllm_logger.py +212 -0
  28. tpu_inference/mock/vllm_logging_utils.py +15 -0
  29. tpu_inference/models/common/model_loader.py +10 -43
  30. tpu_inference/models/jax/llama3.py +1 -2
  31. tpu_inference/models/jax/llama_eagle3.py +5 -8
  32. tpu_inference/models/jax/phi3.py +376 -0
  33. tpu_inference/models/jax/qwen2.py +1 -2
  34. tpu_inference/models/jax/qwen2_5_vl.py +48 -163
  35. tpu_inference/models/jax/qwen3.py +1 -2
  36. tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
  37. tpu_inference/models/jax/utils/weight_utils.py +143 -198
  38. tpu_inference/models/vllm/vllm_model_wrapper.py +8 -14
  39. tpu_inference/platforms/tpu_platform.py +31 -37
  40. tpu_inference/runner/compilation_manager.py +58 -141
  41. tpu_inference/runner/kv_cache.py +1 -1
  42. tpu_inference/runner/kv_cache_manager.py +18 -17
  43. tpu_inference/runner/persistent_batch_manager.py +2 -40
  44. tpu_inference/runner/structured_decoding_manager.py +3 -2
  45. tpu_inference/runner/tpu_runner.py +147 -271
  46. tpu_inference/runner/utils.py +2 -2
  47. tpu_inference/spec_decode/jax/eagle3.py +21 -71
  48. tpu_inference/tpu_info.py +3 -4
  49. tpu_inference/utils.py +13 -36
  50. tpu_inference/worker/tpu_worker.py +25 -162
  51. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/METADATA +3 -4
  52. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/RECORD +55 -50
  53. tpu_inference/models/jax/llama_guard_4.py +0 -361
  54. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/WHEEL +0 -0
  55. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/licenses/LICENSE +0 -0
  56. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/top_level.txt +0 -0
@@ -267,7 +267,6 @@ def _ragged_paged_attention_kernel(
     *,
     sm_scale: float,
     sliding_window: int | None = None,
-    strict_sliding_window: bool = True,
     soft_cap: float | None = None,
     mask_value: float = DEFAULT_MASK_VALUE,
     q_scale: float | None = None,
@@ -318,20 +317,19 @@ def _ragged_paged_attention_kernel(
   q_len = q_end - q_start
   kv_len = kv_lens_ref[seq_idx]
 
+  bkv_idx_start = 0 if sliding_window is None else jnp.maximum(
+      kv_len - sliding_window, 0) // bkv_sz
+
   if sliding_window is None:
-    bkv_idx_start = next_seq_bkv_idx_start = 0
+    next_bkv_idx_start = 0
   else:
-    bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window,
-                                0) // bkv_sz
 
     def get_next_bkv_idx_start():
       next_kv_len = kv_lens_ref[seq_idx + 1]
-      next_q_len = cu_q_lens_ref[seq_idx + 2] - q_end
-      return jnp.maximum(next_kv_len - next_q_len - sliding_window,
-                         0) // bkv_sz
+      return jnp.maximum(next_kv_len - sliding_window, 0) // bkv_sz
 
-    next_seq_bkv_idx_start = lax.cond(seq_idx + 1 < num_seqs,
-                                      get_next_bkv_idx_start, lambda: 0)
+    next_bkv_idx_start = lax.cond(seq_idx + 1 < num_seqs,
+                                  get_next_bkv_idx_start, lambda: 0)
 
   def debug_print(msg, *args):
     if debug_mode:
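
For orientation, a standalone sketch of the new block-start rule (toy values, not from the package): with a sliding window, the kernel now starts at the first KV block that can still hold an in-window key, derived from kv_len alone rather than from kv_len - q_len.

    # Hypothetical toy values; bkv_sz, kv_len, sliding_window mirror the names above.
    import jax.numpy as jnp

    bkv_sz = 128
    kv_len = jnp.int32(1000)
    sliding_window = 512

    # First block index that can still contain a key inside the window.
    bkv_idx_start = jnp.maximum(kv_len - sliding_window, 0) // bkv_sz  # (1000 - 512) // 128 = 3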
@@ -352,7 +350,7 @@ def _ragged_paged_attention_kernel(
   debug_print("[RPA debug] q_len={}", q_len)
   debug_print("[RPA debug] kv_len={}", kv_len)
 
-  def flash_attention_step1_qk_softmax(
+  def flash_attention(
       q,  # [actual_bq_sz * num_q_heads_per_kv_head, actual_head_dim_x2]
       kv,  # [bkv_sz, actual_head_dim_x2]
       *,
@@ -366,6 +364,7 @@ def _ragged_paged_attention_kernel(
     assert kv.shape == (bkv_sz, actual_head_dim_x2)
     head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
     head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
+    head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]
 
     def load_with_init(ref, init_val):
       return jnp.where(bkv_idx == bkv_idx_start,
@@ -387,19 +386,16 @@ def _ragged_paged_attention_kernel(
       s *= k_scale
     if q_scale is not None:
       s *= q_scale
-    if soft_cap is not None:
-      s = soft_cap * jnp.tanh(s / soft_cap)
 
     q_span = (kv_len - q_len + bq_idx * bq_sz +
               lax.broadcasted_iota(jnp.int32, s.shape, 0) //
               num_q_heads_per_kv_head)
     k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
-    mask = k_span <= q_span
+    mask = q_span < k_span
 
-    if sliding_window is not None and strict_sliding_window:
-      mask = jnp.logical_and(mask, q_span - sliding_window < k_span)
-
-    s = jnp.where(mask, s, mask_value)
+    if soft_cap is not None:
+      s = soft_cap * jnp.tanh(s / soft_cap)
+    s += jnp.where(mask, mask_value, 0.0)
     s_rowmax = jnp.max(s, axis=1, keepdims=True)
 
     if attention_sink_ref is not None:
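
The reworked masking applies an additive mask after the optional soft cap instead of overwriting masked logits. A self-contained illustration of the q_span/k_span construction (toy sizes; a large negative constant stands in for DEFAULT_MASK_VALUE):

    # Illustrative only: toy block sizes, single query head per KV head.
    import jax.numpy as jnp
    from jax import lax

    q_len, kv_len, bq_sz, bkv_sz = 3, 5, 3, 5
    bq_idx = bkv_idx = 0
    num_q_heads_per_kv_head = 1
    shape = (q_len, bkv_sz)

    q_span = (kv_len - q_len + bq_idx * bq_sz +
              lax.broadcasted_iota(jnp.int32, shape, 0) // num_q_heads_per_kv_head)
    k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, shape, 1)
    mask = q_span < k_span                    # True where the key lies in the future
    s = jnp.zeros(shape, jnp.float32)         # stand-in logits
    s += jnp.where(mask, -1e30, 0.0)          # masked positions get a huge negative bias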
@@ -415,33 +411,15 @@ def _ragged_paged_attention_kernel(
     head_m_ref[...] = m_curr
     p = jnp.exp(s - broadcast_minor(m_curr, s.shape))
 
+    pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
+    if v_scale is not None:
+      pv *= v_scale
+
     p_rowsum = jnp.sum(p, axis=1, keepdims=True)
     exp_m_diff = jnp.exp(m_prev - m_curr)
     l_prev = load_with_init(head_l_ref, 1.0)
     l_curr = exp_m_diff * l_prev + p_rowsum
     head_l_ref[...] = l_curr
-
-    return p, exp_m_diff
-
-  def flash_attention_step2_pv(
-      q_shape_0,
-      kv,  # [bkv_sz, actual_head_dim_x2]
-      p,  # from step1
-      exp_m_diff,  # from step1
-      *,
-      bkv_idx,
-      kv_head_idx,
-  ):
-    head_acc_ref = acc_ref.at[kv_head_idx, :q_shape_0]
-
-    def load_with_init(ref, init_val):
-      return jnp.where(bkv_idx == bkv_idx_start,
-                       jnp.full_like(ref, init_val), ref[...])
-
-    pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
-    if v_scale is not None:
-      pv *= v_scale
-
     o_prev = load_with_init(head_acc_ref, 0.0)
     o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
     head_acc_ref[...] = o_curr
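
For readers unfamiliar with the pattern, the merged flash_attention body above is the standard online-softmax update; a reference (non-Pallas) sketch, ignoring scales, attention sinks, and the Pallas refs:

    # Reference-only sketch; names are illustrative, not the kernel's API.
    import jax.numpy as jnp

    def online_softmax_step(o_prev, l_prev, m_prev, s, v):
        # s: [q, bkv] masked logits for one KV block, v: [bkv, d] values.
        m_curr = jnp.maximum(m_prev, jnp.max(s, axis=1, keepdims=True))
        p = jnp.exp(s - m_curr)
        pv = jnp.einsum("nm,md->nd", p, v, preferred_element_type=jnp.float32)
        exp_m_diff = jnp.exp(m_prev - m_curr)
        l_curr = exp_m_diff * l_prev + jnp.sum(p, axis=1, keepdims=True)
        o_curr = exp_m_diff * o_prev + pv
        return o_curr, l_curr, m_curr   # divide o by l after the last KV block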
@@ -456,12 +434,7 @@ def _ragged_paged_attention_kernel(
     else:
       cp.start()
 
-  def _fetch_bkv(seq_idx,
-                 bkv_idx,
-                 bkv_sem_idx,
-                 *,
-                 is_full_fetch=False,
-                 wait=False):
+  def _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, wait=False):
     sem = sems.at[0, bkv_sem_idx]
     vmem_ref = bkv_x2_ref.at[bkv_sem_idx]
 
@@ -502,73 +475,42 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
     debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
 
-    if not wait:
-      # Fetch effective kv from kv cache.
-      def loop_body(i, offset):
-        sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
-        _async_copy(
-            cache_hbm_ref.at[pl.ds(
-                page_indices_ref[page_indices_offset + i] * page_size,
-                sz)],
-            vmem_ref.at[pl.ds(i * page_size, sz)],
-            sem,
-            wait=False,
-        )
-        debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-        return offset + sz
-
-      offset = lax.fori_loop(
-          0,
-          bkv_p_frm_cache,
-          loop_body,
-          0,  # offset
-          unroll=False,
+    # Fetch effective kv from kv cache.
+    def loop_body(i, offset):
+      sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
+      _async_copy(
+          cache_hbm_ref.at[pl.ds(
+              page_indices_ref[page_indices_offset + i] * page_size,
+              sz)],
+          vmem_ref.at[pl.ds(i * page_size, sz)],
+          sem,
+          wait,
      )
+      debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+      return offset + sz
+
+    offset = lax.fori_loop(
+        0,
+        bkv_p_frm_cache,
+        loop_body,
+        0,  # offset
+        unroll=False,
+    )
 
-      # Fetch kv directly from new kv.
-      @pl.when(bkv_sz_frm_new > 0)
-      def _fetch_bkv_from_new_kv():
-        new_kv_len_start = q_end - kv_left_frm_new
-        debug_print("[RPA debug] new_kv_len_start={}",
-                    new_kv_len_start)
-        debug_print("[RPA debug] offset_in_bkv={}", offset)
-        _async_copy(
-            kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
-            vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
-            sem,
-            wait,
-        )
-
-      # NOTE(chengjiyao): This condition is true for the first two bkv fetches.
-      # We need to ensure the bkv_x2_ref VMEM buffer is fully initialized to
-      # avoid potential NaN values in regions not overwritten by actual data.
-      # This is done by padding the remaining parts of the buffer with data
-      # from the KV cache. This special handling is only strictly necessary
-      # until both buffers in the double buffer (bkv_x2_ref) have been written
-      # to at least once.
-      @pl.when(is_full_fetch)
-      def _make_sure_bkv_vmem_is_not_nan():
-        effective_sz = offset + bkv_sz_frm_new
-        remaining_sz = bkv_sz - effective_sz
-        _async_copy(
-            cache_hbm_ref.at[pl.ds(0, remaining_sz)],
-            vmem_ref.at[pl.ds(effective_sz, remaining_sz)],
-            sem,
-            wait,
-        )
-
-      return kv_len_start + offset, bkv_sz_frm_new
-    else:
-      offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
-      sz = lax.select(is_full_fetch, bkv_sz, offset + bkv_sz_frm_new)
-      dst = vmem_ref.at[pl.ds(0, sz)]
+    # Fetch kv directly from new kv.
+    @pl.when(bkv_sz_frm_new > 0)
+    def _fetch_bkv_from_new_kv():
+      new_kv_len_start = q_end - kv_left_frm_new
+      debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
+      debug_print("[RPA debug] offset_in_bkv={}", offset)
       _async_copy(
-          src=dst,
-          dst=dst,
-          sem=sem,
-          wait=True,
+          kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
+          vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
+          sem,
+          wait,
      )
-      return kv_len_start + offset, bkv_sz_frm_new
+
+    return kv_len_start + offset, bkv_sz_frm_new
 
   def _update_kv_cache(seq_idx,
                        bkv_sem_idx,
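
The loop above walks the sequence's page table and copies one page at a time into the block buffer. A toy, non-Pallas rendering of the same gather, with all names and shapes invented for illustration:

    # Illustrative only; the real fetch issues Pallas async copies, not concatenation.
    import jax.numpy as jnp
    from jax import lax

    page_size, num_pages, head_dim = 4, 8, 2
    cache = jnp.arange(num_pages * page_size * head_dim,
                       dtype=jnp.float32).reshape(num_pages * page_size, head_dim)
    page_indices = [5, 2, 7]        # pages owned by this sequence, in order
    kv_left_frm_cache = 10          # only 10 cached tokens are valid

    chunks = []
    for i, p in enumerate(page_indices):
        sz = min(page_size, kv_left_frm_cache - i * page_size)
        page = lax.dynamic_slice_in_dim(cache, p * page_size, page_size)
        chunks.append(page[:sz])
    bkv = jnp.concatenate(chunks, axis=0)   # contiguous [10, head_dim] block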
@@ -604,41 +546,30 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] p_ignore={}", p_ignore)
     debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
 
-    if not wait:
-
-      def loop_body(i, states):
-        update_sz, ignore = states
-        sz = jnp.minimum(page_size - ignore, update_sz)
-
-        _async_copy(
-            vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
-                              sz)],
-            cache_hbm_ref.at[pl.ds(
-                page_indices_ref[page_indices_offset + i] * page_size +
-                ignore,
-                sz,
-            )],
-            sem,
-            wait=False,
-        )
-        debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-        return update_sz - sz, 0
-
-      lax.fori_loop(
-          0,
-          kv_p_end - kv_p_start,
-          loop_body,
-          (update_sz, ignore),  # total transfer size
-          unroll=False,
-      )
-    else:
-      dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
+    def loop_body(i, states):
+      update_sz, ignore = states
+      sz = jnp.minimum(page_size - ignore, update_sz)
+
       _async_copy(
-          src=dst,
-          dst=dst,
-          sem=sem,
-          wait=True,
+          vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore, sz)],
+          cache_hbm_ref.at[pl.ds(
+              page_indices_ref[page_indices_offset + i] * page_size +
+              ignore,
+              sz,
+          )],
+          sem,
+          wait,
      )
+      debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+      return update_sz - sz, 0
+
+    lax.fori_loop(
+        0,
+        kv_p_end - kv_p_start,
+        loop_body,
+        (update_sz, ignore),  # total transfer size
+        unroll=False,
+    )
 
   def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
     sem = sems.at[1, bq_sem_idx]
@@ -688,18 +619,11 @@ def _ragged_paged_attention_kernel(
         wait,
     )
 
-  def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, is_full_fetch=False):
-    return _fetch_bkv(seq_idx,
-                      bkv_idx,
-                      bkv_sem_idx,
-                      is_full_fetch=is_full_fetch)
+  def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
+    return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)
 
-  def wait_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, is_full_fetch=False):
-    return _fetch_bkv(seq_idx,
-                      bkv_idx,
-                      bkv_sem_idx,
-                      is_full_fetch=is_full_fetch,
-                      wait=True)
+  def wait_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
+    return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, wait=True)
 
   def start_fetch_bq(seq_idx, bq_idx, bq_sem_idx):
     return _fetch_bq(seq_idx, bq_idx, bq_sem_idx)
@@ -757,7 +681,7 @@ def _ragged_paged_attention_kernel(
     vec = ref[start::step]
     return vec
 
-  def strided_load_bkv(bkv_sem_idx, start, step):
+  def strided_load_bkv(bkv_sem_idx, start, step, *, bkv_mask):
     assert start % kv_packing == 0
     assert step % kv_packing == 0
     start //= kv_packing
@@ -766,6 +690,7 @@ def _ragged_paged_attention_kernel(
                       bkv_sz * step, actual_head_dim_x2))
 
     kv = strided_load(kv_ref, start, step)
+    kv = lax.select(bkv_mask, kv, jnp.zeros_like(kv))
    bitwidth = 32 // kv_packing
    repack_ty = jnp.dtype(f"uint{bitwidth}")
    lst = []
@@ -812,17 +737,13 @@ def _ragged_paged_attention_kernel(
     next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
     next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)
 
-    if sliding_window is None:
-      next_bkv_start_idx = 0
-    else:
-      next_bkv_start_idx = lax.select(
+    next_bkv_idx = lax.select(
+        is_last_bkv,
+        lax.select(
            is_last_bq,
-           next_seq_bkv_idx_start,
+           next_bkv_idx_start,
            bkv_idx_start,
-      )
-    next_bkv_idx = lax.select(is_last_bkv, next_bkv_start_idx,
-                              next_bkv_idx)
-
+        ), next_bkv_idx)
     return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx
 
   def compute_with_bq(bq_idx, _):
@@ -839,23 +760,22 @@ def _ragged_paged_attention_kernel(
     def compute_with_bkv(bkv_idx, _):
       # Create bitmask for KV.
       assert bkv_sz % kv_packing == 0
+      actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
+      bkv_shape = (bkv_sz, actual_head_dim_x2)
+      bkv_mask = lax.broadcasted_iota(jnp.int32, bkv_shape,
+                                      0) < actual_bkv_sz
 
       # Get next bkv ids.
       bkv_sem_idx = sem_ids_ref[1]
-      next_seq_idx, next_bq_idx_for_kv, next_bkv_idx, next_bkv_sem_idx = (
-          get_next_bkv_ids(seq_idx, bq_idx, bkv_idx, bkv_sem_idx))
+      next_seq_idx, _, next_bkv_idx, next_bkv_sem_idx = get_next_bkv_ids(
+          seq_idx, bq_idx, bkv_idx, bkv_sem_idx)
 
       # Prefetch next bkv
       @pl.when(next_seq_idx < num_seqs)
       def prefetch_next_bkv():
         sem_ids_ref[1] = next_bkv_sem_idx
-        start_fetch_bkv(
-            next_seq_idx,
-            next_bkv_idx,
-            next_bkv_sem_idx,
-            is_full_fetch=next_seq_idx + next_bq_idx_for_kv +
-            next_bkv_idx < 2,
-        )
+        start_fetch_bkv(next_seq_idx, next_bkv_idx,
+                        next_bkv_sem_idx)
 
       # Wait for cur bq if not ready yet
       @pl.when(bkv_idx == bkv_idx_start)
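
The newly built bkv_mask zeroes the tail of a partially filled KV block before it is repacked and handed to flash_attention. A standalone toy version, with made-up sizes:

    # Illustrative shapes only; the kernel works on packed low-bit KV data.
    import jax.numpy as jnp
    from jax import lax

    bkv_sz, head_dim_x2 = 8, 4
    kv_len, bkv_idx = jnp.int32(13), jnp.int32(1)     # second block holds 5 valid rows

    actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
    bkv_mask = lax.broadcasted_iota(jnp.int32, (bkv_sz, head_dim_x2), 0) < actual_bkv_sz
    kv = jnp.ones((bkv_sz, head_dim_x2), jnp.float32)
    kv = lax.select(bkv_mask, kv, jnp.zeros_like(kv))  # rows 5..7 become zeros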
@@ -863,12 +783,8 @@ def _ragged_paged_attention_kernel(
         wait_fetch_bq(seq_idx, bq_idx, bq_sem_idx)
 
       # Wait for cur bkv
-      offset, update_sz = wait_fetch_bkv(
-          seq_idx,
-          bkv_idx,
-          bkv_sem_idx,
-          is_full_fetch=seq_idx + bq_idx + bkv_idx < 2,
-      )
+      offset, update_sz = wait_fetch_bkv(seq_idx, bkv_idx,
+                                         bkv_sem_idx)
 
       # Start updating bkv to kv cache if applicable.
       # Only needed in first bq loop.
@@ -887,64 +803,29 @@ def _ragged_paged_attention_kernel(
           return
 
       # Flash attention with cur bkv and bq
-      prev_bq_shape_0 = None
-      prev_kv_head_bkv = None
-      prev_kv_head_idx = None
-      prev_kv_head_p = None
-      prev_kv_head_exp_m_diff = None
       for kv_head_start in range(0, actual_num_kv_heads, kv_packing):
         bkv_lst = strided_load_bkv(
             bkv_sem_idx,
             kv_head_start,
             num_kv_heads,
+            bkv_mask=bkv_mask,
        )
         assert len(bkv_lst) == kv_packing
         for i in range(kv_packing):
-          cur_kv_head_idx = kv_head_start + i
-          if cur_kv_head_idx >= actual_num_kv_heads:
+          kv_head_idx = kv_head_start + i
+          if kv_head_idx >= actual_num_kv_heads:
            break
-          cur_kv_head_bq = load_bq(bq_sem_idx,
-                                   cur_kv_head_idx,
-                                   actual_bq_sz=actual_bq_sz)
-          cur_kv_head__bkv = bkv_lst[i]
-          # FlashAttention is divided into `flash_attention_step1_qk_softmax`
-          # and `flash_attention_step2_pv` to pipeline the computation.
-          # `step2_pv` for the previous KV head, which depends on the softmax
-          # output, is overlapped with `step1_qk_softmax` for the current KV
-          # head, reducing overall wait times.
-          cur_kv_head_p, cur_kv_head_exp_m_diff = (
-              flash_attention_step1_qk_softmax(
-                  cur_kv_head_bq,
-                  cur_kv_head__bkv,
-                  bq_idx=bq_idx,
-                  bkv_idx=bkv_idx,
-                  kv_head_idx=cur_kv_head_idx,
-              ))
-          if prev_bq_shape_0 is not None:
-            flash_attention_step2_pv(
-                prev_bq_shape_0,
-                prev_kv_head_bkv,
-                prev_kv_head_p,
-                prev_kv_head_exp_m_diff,
-                bkv_idx=bkv_idx,
-                kv_head_idx=prev_kv_head_idx,
-            )
-          prev_bq_shape_0 = cur_kv_head_bq.shape[0]
-          prev_kv_head_bkv = cur_kv_head__bkv
-          prev_kv_head_p = cur_kv_head_p
-          prev_kv_head_exp_m_diff = cur_kv_head_exp_m_diff
-          prev_kv_head_idx = cur_kv_head_idx
-
-        # Execute pv of last attention head.
-        assert prev_bq_shape_0 is not None
-        flash_attention_step2_pv(
-            prev_bq_shape_0,
-            prev_kv_head_bkv,
-            prev_kv_head_p,
-            prev_kv_head_exp_m_diff,
-            bkv_idx=bkv_idx,
-            kv_head_idx=prev_kv_head_idx,
-        )
+          bq = load_bq(bq_sem_idx,
+                       kv_head_idx,
+                       actual_bq_sz=actual_bq_sz)
+          bkv = bkv_lst[i]
+          flash_attention(
+              bq,
+              bkv,
+              bq_idx=bq_idx,
+              bkv_idx=bkv_idx,
+              kv_head_idx=kv_head_idx,
+          )
 
       lax.fori_loop(bkv_idx_start,
                     num_bkv,
@@ -980,7 +861,7 @@ def _ragged_paged_attention_kernel(
   @pl.when(seq_idx == 0)
   def prologue():
     start_fetch_bq(0, 0, 0)
-    start_fetch_bkv(0, bkv_idx_start, 0, is_full_fetch=True)
+    start_fetch_bkv(0, bkv_idx_start, 0)
 
   @pl.when(seq_idx < decode_end)
   def process_decode():
@@ -1345,7 +1226,6 @@ def static_validate_inputs(
     static_argnames=(
         "sm_scale",
         "sliding_window",
-        "strict_sliding_window",
         "soft_cap",
         "mask_value",
         "q_scale",
@@ -1375,7 +1255,6 @@ def ragged_paged_attention_hd64(
     *,
     sm_scale: float = 1.0,
     sliding_window: int | None = None,
-    strict_sliding_window: bool = True,
     soft_cap: float | None = None,
     mask_value: float | None = DEFAULT_MASK_VALUE,
     q_scale: float | None = None,
@@ -1390,41 +1269,42 @@ def ragged_paged_attention_hd64(
     # Debug params.
     debug_mode: bool = False,
 ):
-  """A variant of ragged paged attention for head_dim=64.
-
-  Args:
-    queries: concatenated all sequences' queries.
-    keys: concatenated all sequences' keys (quantized).
-    values: concatenated all sequences' values (quantized).
-    kv_cache: paged KV cache with TPU-friendly shape.
-    kv_lens: padded kv lengths. Only the first num_seqs values are valid.
-    page_indices: flattened page indices look-up table by (seq_id, page_id).
-    cu_q_lens: the cumulative sum of the effective query lengths. Similar to
-      kv_lens, only the first num_seqs+1 values are valid.
-    distribution: (i, j, k) represents that sequences[0:i] are decode-only,
-      sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
-      k is also the total number of sequences.
-    attention_sink: optional attention sink for each q head.
-    sm_scale: the softmax scale which will be applied to the Q@K^T.
-    sliding_window: the sliding window size for the attention.
-    strict_sliding_window: compute tokens that are strictly within the window.
-    soft_cap: the logit soft cap for the attention.
-    mask_value: mask value for causal mask.
-    q_scale: the scale for the query.
-    k_scale: the scale for the key cache.
-    v_scale: the scale for the value cache.
-    chunk_prefill_size: the chunk prefill size for the attention.
-    num_kv_pages_per_block: number of kv pages to be processed in one flash
-      attention block in the pallas kernel.
-    num_queries_per_block: number of kv pages to be processed in one flash
-      attention block in the pallas kernel.
-    vmem_limit_bytes: the vmem limit for the pallas kernel.
-    debug_mode: if true, RPA does not issue any DMAs or run flash attention but
-      print debug info. Need to compile with `--xla_tpu_enable_log_recorder`.
-
-  Returns:
-    The output of the attention.
-  """
+  """A special Ragged paged attention version for head_dim=64 that supports mixed
+
+  prefill and decode.
+
+  Args:
+    queries: concatenated all sequences' queries.
+    keys: concatenated all sequences' keys (quantized).
+    values: concatenated all sequences' values (quantized).
+    kv_cache: paged KV cache with TPU-friendly shape.
+    kv_lens: padded kv lengths. Only the first num_seqs values are valid.
+    page_indices: flattened page indices look-up table by (seq_id, page_id).
+    cu_q_lens: the cumulative sum of the effective query lengths. Similar to
+      kv_lens, only the first num_seqs+1 values are valid.
+    distribution: (i, j, k) represents that sequences[0:i] are decode-only,
+      sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
+      k is also the total number of sequences.
+    attention_sink: optional attention sink for each q head.
+    actual_head_dim: the actual head size of the attention. Here we assume k and
+      v have the same actual head size.
+    sm_scale: the softmax scale which will be applied to the Q@K^T.
+    sliding_window: the sliding window size for the attention.
+    soft_cap: the logit soft cap for the attention.
+    mask_value: mask value for causal mask.
+    k_scale: the scale for the key cache.
+    v_scale: the scale for the value cache.
+    num_kv_pages_per_block: number of kv pages to be processed in one flash
+      attention block in the pallas kernel.
+    num_queries_per_block: number of kv pages to be processed in one flash
+      attention block in the pallas kernel.
+    vmem_limit_bytes: the vmem limit for the pallas kernel.
+    debug_mode: if true, RPA does not issue any DMAs or run flash attention but
+      print debug info. Need to compile with `--xla_tpu_enable_log_recorder`.
+
+  Returns:
+    The output of the attention.
+  """
   q, k, v = queries, keys, values
   static_validate_inputs(
       q,
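
To make the batch-description arguments concrete, here is a hypothetical three-sequence batch expressed in the conventions the docstring describes (numbers invented for illustration):

    # Hypothetical example, not taken from the package.
    import jax.numpy as jnp

    kv_lens = jnp.array([17, 40, 33])      # per-sequence KV lengths (first num_seqs valid)
    cu_q_lens = jnp.array([0, 1, 9, 21])   # seq i owns query tokens cu_q_lens[i]:cu_q_lens[i+1]
    distribution = (1, 3, 3)               # seqs[0:1] decode-only, seqs[1:3] chunked prefill,
                                           # seqs[3:3] mixed; 3 sequences in total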
@@ -1494,7 +1374,7 @@ def ragged_paged_attention_hd64(
       pl.BlockSpec(memory_space=pltpu.HBM),
       pl.BlockSpec(memory_space=pltpu.HBM),
       None if attention_sink is None else pl.BlockSpec(
-          memory_space=pltpu.VMEM),
+          memory_space=pltpu.VMEM)
   ]
 
   out_specs = [
@@ -1558,7 +1438,6 @@ def ragged_paged_attention_hd64(
           _ragged_paged_attention_kernel,
           sm_scale=sm_scale,
           sliding_window=sliding_window,
-          strict_sliding_window=strict_sliding_window,
           soft_cap=soft_cap,
           mask_value=mask_value,
           q_scale=q_scale,
@@ -308,13 +308,7 @@ def sharded_ragged_paged_attention(
     args = (q, k, v, kv_cache, kv_lens, page_indices, cu_q_lens, distribution)
 
     use_hd64 = q.shape[-1] == 64
-
-    func = ragged_paged_attention
-    if use_hd64:
-        func = functools.partial(ragged_paged_attention_hd64,
-                                 strict_sliding_window=False)
-    else:
-        func = ragged_paged_attention
+    func = ragged_paged_attention_hd64 if use_hd64 else ragged_paged_attention
 
     if attention_sink is not None:
         if not use_hd64:
@@ -1,5 +1,6 @@
 import json
 import math
+import os
 from dataclasses import asdict, dataclass
 from typing import TYPE_CHECKING, List, Optional
 
@@ -7,7 +8,7 @@ import jax.numpy as jnp
 import numpy as np
 from jax.sharding import Mesh
 
-from tpu_inference import envs, utils
+from tpu_inference import utils
 
 if TYPE_CHECKING:
     from vllm.v1.configs.vllm_config import VllmConfig
@@ -47,7 +48,7 @@ class ShardingAxisName2D:
 
 
 try:
-    _use_base_sharding = envs.NEW_MODEL_DESIGN
+    _use_base_sharding = os.getenv("NEW_MODEL_DESIGN", False)
     if _use_base_sharding:
         ShardingAxisName = ShardingAxisNameBase
     else:
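
Note the flag is now read straight from the environment rather than through tpu_inference.envs; a minimal standalone sketch of the new lookup (illustrative only):

    import os

    # os.getenv returns the raw string, so any non-empty value (even "0") is truthy.
    use_base_sharding = bool(os.getenv("NEW_MODEL_DESIGN", False))
    print("base sharding axes enabled:", use_base_sharding)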
@@ -165,10 +166,9 @@ class ShardingConfigManager:
                 f"LoRA is not supported with data parallelism "
                 f"(DP size: {total_dp_size}). Please disable LoRA or "
                 f"set data parallelism to 1.")
-        if sharding_strategy.attention_data_parallelism > 1:
-            if not envs.NEW_MODEL_DESIGN:
+        if not os.environ.get("NEW_MODEL_DESIGN", False):
             raise ValueError(
-                "Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set the "
+                "Must run DP with NEW_MODEL_DESIGN enabled. Please set the "
                 "NEW_MODEL_DESIGN=True.")
 
     @property