tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202511270815__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (49)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/lora/test_layers.py +0 -6
  3. tests/lora/utils.py +0 -8
  4. tpu_inference/__init__.py +22 -3
  5. tpu_inference/core/disagg_utils.py +6 -8
  6. tpu_inference/distributed/tpu_connector.py +2 -3
  7. tpu_inference/distributed/utils.py +3 -2
  8. tpu_inference/envs.py +1 -1
  9. tpu_inference/executors/ray_distributed_executor.py +27 -11
  10. tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
  11. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
  12. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +141 -107
  13. tpu_inference/layers/common/attention_interface.py +7 -1
  14. tpu_inference/layers/common/sharding.py +2 -1
  15. tpu_inference/layers/vllm/fused_moe.py +74 -25
  16. tpu_inference/layers/vllm/quantization/common.py +6 -1
  17. tpu_inference/layers/vllm/quantization/mxfp4.py +135 -61
  18. tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
  19. tpu_inference/layers/vllm/sharding.py +2 -2
  20. tpu_inference/lora/torch_punica_tpu.py +1 -2
  21. tpu_inference/models/common/model_loader.py +43 -11
  22. tpu_inference/models/jax/llama3.py +2 -1
  23. tpu_inference/models/jax/llama_eagle3.py +8 -5
  24. tpu_inference/models/jax/llama_guard_4.py +361 -0
  25. tpu_inference/models/jax/qwen2.py +2 -1
  26. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  27. tpu_inference/models/jax/qwen3.py +2 -1
  28. tpu_inference/models/jax/utils/weight_utils.py +198 -143
  29. tpu_inference/models/vllm/vllm_model_wrapper.py +13 -5
  30. tpu_inference/platforms/tpu_platform.py +15 -2
  31. tpu_inference/runner/compilation_manager.py +58 -33
  32. tpu_inference/runner/kv_cache_manager.py +9 -3
  33. tpu_inference/runner/structured_decoding_manager.py +2 -3
  34. tpu_inference/runner/tpu_runner.py +203 -102
  35. tpu_inference/spec_decode/jax/eagle3.py +19 -2
  36. tpu_inference/tpu_info.py +4 -3
  37. tpu_inference/utils.py +5 -4
  38. tpu_inference/worker/tpu_worker.py +160 -23
  39. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/METADATA +3 -2
  40. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/RECORD +43 -48
  41. tpu_inference/mock/__init__.py +0 -0
  42. tpu_inference/mock/vllm_config_utils.py +0 -28
  43. tpu_inference/mock/vllm_envs.py +0 -1219
  44. tpu_inference/mock/vllm_logger.py +0 -212
  45. tpu_inference/mock/vllm_logging_utils.py +0 -15
  46. tpu_inference/models/jax/phi3.py +0 -376
  47. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/WHEEL +0 -0
  48. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/licenses/LICENSE +0 -0
  49. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/top_level.txt +0 -0
@@ -440,42 +440,54 @@ def _ragged_paged_attention_kernel(
  debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
  debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

- # Fetch effective kv from kv cache.
- def loop_body(i, offset):
- sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
- _async_copy(
- cache_hbm_ref.at[pl.ds(
- page_indices_ref[page_indices_offset + i] * page_size,
- sz)],
- vmem_ref.at[pl.ds(i * page_size, sz)],
- sem,
- wait,
+ if not wait:
+ # Fetch effective kv from kv cache.
+ def loop_body(i, offset):
+ sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
+ _async_copy(
+ cache_hbm_ref.at[pl.ds(
+ page_indices_ref[page_indices_offset + i] * page_size,
+ sz)],
+ vmem_ref.at[pl.ds(i * page_size, sz)],
+ sem,
+ wait=False,
+ )
+ debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+ return offset + sz
+
+ offset = lax.fori_loop(
+ 0,
+ bkv_p_frm_cache,
+ loop_body,
+ 0, # offset
+ unroll=False,
  )
- debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
- return offset + sz
-
- offset = lax.fori_loop(
- 0,
- bkv_p_frm_cache,
- loop_body,
- 0, # offset
- unroll=False,
- )

- # Fetch kv directly from new kv.
- @pl.when(bkv_sz_frm_new > 0)
- def _fetch_bkv_from_new_kv():
- new_kv_len_start = q_end - kv_left_frm_new
- debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
- debug_print("[RPA debug] offset_in_bkv={}", offset)
+ # Fetch kv directly from new kv.
+ @pl.when(bkv_sz_frm_new > 0)
+ def _fetch_bkv_from_new_kv():
+ new_kv_len_start = q_end - kv_left_frm_new
+ debug_print("[RPA debug] new_kv_len_start={}",
+ new_kv_len_start)
+ debug_print("[RPA debug] offset_in_bkv={}", offset)
+ _async_copy(
+ kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
+ vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
+ sem,
+ wait,
+ )
+
+ return kv_len_start + offset, bkv_sz_frm_new
+ else:
+ offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
+ dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
  _async_copy(
- kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
- vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
- sem,
- wait,
+ src=dst,
+ dst=dst,
+ sem=sem,
+ wait=True,
  )
-
- return kv_len_start + offset, bkv_sz_frm_new
+ return kv_len_start + offset, bkv_sz_frm_new

  def _update_kv_cache(seq_idx,
  bkv_sem_idx,
@@ -511,30 +523,41 @@ def _ragged_paged_attention_kernel(
  debug_print("[RPA debug] p_ignore={}", p_ignore)
  debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

- def loop_body(i, states):
- update_sz, ignore = states
- sz = jnp.minimum(page_size - ignore, update_sz)
-
+ if not wait:
+
+ def loop_body(i, states):
+ update_sz, ignore = states
+ sz = jnp.minimum(page_size - ignore, update_sz)
+
+ _async_copy(
+ vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
+ sz)],
+ cache_hbm_ref.at[pl.ds(
+ page_indices_ref[page_indices_offset + i] * page_size +
+ ignore,
+ sz,
+ )],
+ sem,
+ wait=False,
+ )
+ debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+ return update_sz - sz, 0
+
+ lax.fori_loop(
+ 0,
+ kv_p_end - kv_p_start,
+ loop_body,
+ (update_sz, ignore), # total transfer size
+ unroll=False,
+ )
+ else:
+ dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
  _async_copy(
- vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore, sz)],
- cache_hbm_ref.at[pl.ds(
- page_indices_ref[page_indices_offset + i] * page_size +
- ignore,
- sz,
- )],
- sem,
- wait,
+ src=dst,
+ dst=dst,
+ sem=sem,
+ wait=True,
  )
- debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
- return update_sz - sz, 0
-
- lax.fori_loop(
- 0,
- kv_p_end - kv_p_start,
- loop_body,
- (update_sz, ignore), # total transfer size
- unroll=False,
- )

  def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
  sem = sems.at[1, bq_sem_idx]
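
Note: the two hunks above split both the KV fetch and the KV-cache update into an issue pass (wait=False), which starts one DMA per page, and a wait pass (wait=True), which now blocks once on a single descriptor covering the whole transferred region instead of re-walking the per-page loop. The plain-Python schematic below mirrors only that control flow; the names and the list-based queue are illustrative stand-ins, not the Pallas DMA API.

def fetch_kv_from_cache(pages, page_size, kv_left, dma_queue, *, wait=False):
    """Issue pass starts one copy per page; wait pass blocks once for the whole region."""
    if not wait:
        offset = 0
        for i, page in enumerate(pages):
            sz = min(page_size, kv_left - i * page_size)  # last page may be partial
            dma_queue.append((page, i * page_size, sz))   # start the copy, do not block
            offset += sz
        return offset
    # A single wait covers everything issued above.
    total = min(kv_left, page_size * len(pages))
    dma_queue.clear()  # "waiting" drains all outstanding copies in this sketch
    return total
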
@@ -267,6 +267,7 @@ def _ragged_paged_attention_kernel(
  *,
  sm_scale: float,
  sliding_window: int | None = None,
+ strict_sliding_window: bool = True,
  soft_cap: float | None = None,
  mask_value: float = DEFAULT_MASK_VALUE,
  q_scale: float | None = None,
@@ -317,19 +318,20 @@ def _ragged_paged_attention_kernel(
  q_len = q_end - q_start
  kv_len = kv_lens_ref[seq_idx]

- bkv_idx_start = 0 if sliding_window is None else jnp.maximum(
- kv_len - sliding_window, 0) // bkv_sz
-
  if sliding_window is None:
- next_bkv_idx_start = 0
+ bkv_idx_start = next_seq_bkv_idx_start = 0
  else:
+ bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window,
+ 0) // bkv_sz

  def get_next_bkv_idx_start():
  next_kv_len = kv_lens_ref[seq_idx + 1]
- return jnp.maximum(next_kv_len - sliding_window, 0) // bkv_sz
+ next_q_len = cu_q_lens_ref[seq_idx + 2] - q_end
+ return jnp.maximum(next_kv_len - next_q_len - sliding_window,
+ 0) // bkv_sz

- next_bkv_idx_start = lax.cond(seq_idx + 1 < num_seqs,
- get_next_bkv_idx_start, lambda: 0)
+ next_seq_bkv_idx_start = lax.cond(seq_idx + 1 < num_seqs,
+ get_next_bkv_idx_start, lambda: 0)

  def debug_print(msg, *args):
  if debug_mode:
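
Note: with a sliding window, the first KV block is now anchored at the first query token (kv_len - q_len) rather than at the end of the sequence, so blocks that the earliest query still needs are no longer skipped. A small numeric illustration with made-up sizes:

kv_len, q_len, sliding_window, bkv_sz = 4096, 128, 1024, 512

old_start = max(kv_len - sliding_window, 0) // bkv_sz          # 3072 // 512 = 6
new_start = max(kv_len - q_len - sliding_window, 0) // bkv_sz  # 2944 // 512 = 5
print(old_start, new_start)  # 6 5
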
@@ -386,16 +388,19 @@ def _ragged_paged_attention_kernel(
  s *= k_scale
  if q_scale is not None:
  s *= q_scale
+ if soft_cap is not None:
+ s = soft_cap * jnp.tanh(s / soft_cap)

  q_span = (kv_len - q_len + bq_idx * bq_sz +
  lax.broadcasted_iota(jnp.int32, s.shape, 0) //
  num_q_heads_per_kv_head)
  k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
- mask = q_span < k_span
+ mask = k_span <= q_span

- if soft_cap is not None:
- s = soft_cap * jnp.tanh(s / soft_cap)
- s += jnp.where(mask, mask_value, 0.0)
+ if sliding_window is not None and strict_sliding_window:
+ mask = jnp.logical_and(mask, q_span - sliding_window < k_span)
+
+ s = jnp.where(mask, s, mask_value)
  s_rowmax = jnp.max(s, axis=1, keepdims=True)

  if attention_sink_ref is not None:
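
Note: the masking convention above is inverted: the mask is now True for positions to keep (causal, and optionally strictly inside the sliding window), logits outside it are overwritten with mask_value, and the soft cap is applied before masking. Below is a runnable toy sketch of the new convention with made-up block sizes; it omits the grouped-head division the real kernel applies to the row index and uses -1e9 as a stand-in mask value.

import jax.numpy as jnp
from jax import lax

bq_sz, bkv_sz = 4, 8
kv_len, q_len = 8, 4
bq_idx = bkv_idx = 0
sliding_window = 3
mask_value = -1e9

s = jnp.zeros((bq_sz, bkv_sz), jnp.float32)  # stand-in for one Q@K^T logits block
q_span = (kv_len - q_len + bq_idx * bq_sz +
          lax.broadcasted_iota(jnp.int32, s.shape, 0))
k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)

mask = k_span <= q_span                                         # causal: key at or before query
mask = jnp.logical_and(mask, q_span - sliding_window < k_span)  # strict sliding window
s = jnp.where(mask, s, mask_value)                              # keep logits where mask is True
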
@@ -475,42 +480,54 @@ def _ragged_paged_attention_kernel(
  debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
  debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

- # Fetch effective kv from kv cache.
- def loop_body(i, offset):
- sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
- _async_copy(
- cache_hbm_ref.at[pl.ds(
- page_indices_ref[page_indices_offset + i] * page_size,
- sz)],
- vmem_ref.at[pl.ds(i * page_size, sz)],
- sem,
- wait,
+ if not wait:
+ # Fetch effective kv from kv cache.
+ def loop_body(i, offset):
+ sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
+ _async_copy(
+ cache_hbm_ref.at[pl.ds(
+ page_indices_ref[page_indices_offset + i] * page_size,
+ sz)],
+ vmem_ref.at[pl.ds(i * page_size, sz)],
+ sem,
+ wait=False,
+ )
+ debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+ return offset + sz
+
+ offset = lax.fori_loop(
+ 0,
+ bkv_p_frm_cache,
+ loop_body,
+ 0, # offset
+ unroll=False,
  )
- debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
- return offset + sz
-
- offset = lax.fori_loop(
- 0,
- bkv_p_frm_cache,
- loop_body,
- 0, # offset
- unroll=False,
- )

- # Fetch kv directly from new kv.
- @pl.when(bkv_sz_frm_new > 0)
- def _fetch_bkv_from_new_kv():
- new_kv_len_start = q_end - kv_left_frm_new
- debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
- debug_print("[RPA debug] offset_in_bkv={}", offset)
+ # Fetch kv directly from new kv.
+ @pl.when(bkv_sz_frm_new > 0)
+ def _fetch_bkv_from_new_kv():
+ new_kv_len_start = q_end - kv_left_frm_new
+ debug_print("[RPA debug] new_kv_len_start={}",
+ new_kv_len_start)
+ debug_print("[RPA debug] offset_in_bkv={}", offset)
+ _async_copy(
+ kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
+ vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
+ sem,
+ wait,
+ )
+
+ return kv_len_start + offset, bkv_sz_frm_new
+ else:
+ offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
+ dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
  _async_copy(
- kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
- vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
- sem,
- wait,
+ src=dst,
+ dst=dst,
+ sem=sem,
+ wait=True,
  )
-
- return kv_len_start + offset, bkv_sz_frm_new
+ return kv_len_start + offset, bkv_sz_frm_new

  def _update_kv_cache(seq_idx,
  bkv_sem_idx,
@@ -546,30 +563,41 @@ def _ragged_paged_attention_kernel(
  debug_print("[RPA debug] p_ignore={}", p_ignore)
  debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

- def loop_body(i, states):
- update_sz, ignore = states
- sz = jnp.minimum(page_size - ignore, update_sz)
-
+ if not wait:
+
+ def loop_body(i, states):
+ update_sz, ignore = states
+ sz = jnp.minimum(page_size - ignore, update_sz)
+
+ _async_copy(
+ vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
+ sz)],
+ cache_hbm_ref.at[pl.ds(
+ page_indices_ref[page_indices_offset + i] * page_size +
+ ignore,
+ sz,
+ )],
+ sem,
+ wait=False,
+ )
+ debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+ return update_sz - sz, 0
+
+ lax.fori_loop(
+ 0,
+ kv_p_end - kv_p_start,
+ loop_body,
+ (update_sz, ignore), # total transfer size
+ unroll=False,
+ )
+ else:
+ dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
  _async_copy(
- vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore, sz)],
- cache_hbm_ref.at[pl.ds(
- page_indices_ref[page_indices_offset + i] * page_size +
- ignore,
- sz,
- )],
- sem,
- wait,
+ src=dst,
+ dst=dst,
+ sem=sem,
+ wait=True,
  )
- debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
- return update_sz - sz, 0
-
- lax.fori_loop(
- 0,
- kv_p_end - kv_p_start,
- loop_body,
- (update_sz, ignore), # total transfer size
- unroll=False,
- )

  def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
  sem = sems.at[1, bq_sem_idx]
@@ -737,13 +765,17 @@ def _ragged_paged_attention_kernel(
  next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
  next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)

- next_bkv_idx = lax.select(
- is_last_bkv,
- lax.select(
+ if sliding_window is None:
+ next_bkv_start_idx = 0
+ else:
+ next_bkv_start_idx = lax.select(
  is_last_bq,
- next_bkv_idx_start,
+ next_seq_bkv_idx_start,
  bkv_idx_start,
- ), next_bkv_idx)
+ )
+ next_bkv_idx = lax.select(is_last_bkv, next_bkv_start_idx,
+ next_bkv_idx)
+
  return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx

  def compute_with_bq(bq_idx, _):
@@ -1226,6 +1258,7 @@ def static_validate_inputs(
  static_argnames=(
  "sm_scale",
  "sliding_window",
+ "strict_sliding_window",
  "soft_cap",
  "mask_value",
  "q_scale",
@@ -1255,6 +1288,7 @@ def ragged_paged_attention_hd64(
  *,
  sm_scale: float = 1.0,
  sliding_window: int | None = None,
+ strict_sliding_window: bool = True,
  soft_cap: float | None = None,
  mask_value: float | None = DEFAULT_MASK_VALUE,
  q_scale: float | None = None,
@@ -1269,42 +1303,41 @@ def ragged_paged_attention_hd64(
  # Debug params.
  debug_mode: bool = False,
  ):
- """A special Ragged paged attention version for head_dim=64 that supports mixed
-
- prefill and decode.
-
- Args:
- queries: concatenated all sequences' queries.
- keys: concatenated all sequences' keys (quantized).
- values: concatenated all sequences' values (quantized).
- kv_cache: paged KV cache with TPU-friendly shape.
- kv_lens: padded kv lengths. Only the first num_seqs values are valid.
- page_indices: flattened page indices look-up table by (seq_id, page_id).
- cu_q_lens: the cumulative sum of the effective query lengths. Similar to
- kv_lens, only the first num_seqs+1 values are valid.
- distribution: (i, j, k) represents that sequences[0:i] are decode-only,
- sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
- k is also the total number of sequences.
- attention_sink: optional attention sink for each q head.
- actual_head_dim: the actual head size of the attention. Here we assume k and
- v have the same actual head size.
- sm_scale: the softmax scale which will be applied to the Q@K^T.
- sliding_window: the sliding window size for the attention.
- soft_cap: the logit soft cap for the attention.
- mask_value: mask value for causal mask.
- k_scale: the scale for the key cache.
- v_scale: the scale for the value cache.
- num_kv_pages_per_block: number of kv pages to be processed in one flash
- attention block in the pallas kernel.
- num_queries_per_block: number of kv pages to be processed in one flash
- attention block in the pallas kernel.
- vmem_limit_bytes: the vmem limit for the pallas kernel.
- debug_mode: if true, RPA does not issue any DMAs or run flash attention but
- print debug info. Need to compile with `--xla_tpu_enable_log_recorder`.
-
- Returns:
- The output of the attention.
- """
+ """A variant of ragged paged attention for head_dim=64.
+
+ Args:
+ queries: concatenated all sequences' queries.
+ keys: concatenated all sequences' keys (quantized).
+ values: concatenated all sequences' values (quantized).
+ kv_cache: paged KV cache with TPU-friendly shape.
+ kv_lens: padded kv lengths. Only the first num_seqs values are valid.
+ page_indices: flattened page indices look-up table by (seq_id, page_id).
+ cu_q_lens: the cumulative sum of the effective query lengths. Similar to
+ kv_lens, only the first num_seqs+1 values are valid.
+ distribution: (i, j, k) represents that sequences[0:i] are decode-only,
+ sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
+ k is also the total number of sequences.
+ attention_sink: optional attention sink for each q head.
+ sm_scale: the softmax scale which will be applied to the Q@K^T.
+ sliding_window: the sliding window size for the attention.
+ strict_sliding_window: compute tokens that are strictly within the window.
+ soft_cap: the logit soft cap for the attention.
+ mask_value: mask value for causal mask.
+ q_scale: the scale for the query.
+ k_scale: the scale for the key cache.
+ v_scale: the scale for the value cache.
+ chunk_prefill_size: the chunk prefill size for the attention.
+ num_kv_pages_per_block: number of kv pages to be processed in one flash
+ attention block in the pallas kernel.
+ num_queries_per_block: number of kv pages to be processed in one flash
+ attention block in the pallas kernel.
+ vmem_limit_bytes: the vmem limit for the pallas kernel.
+ debug_mode: if true, RPA does not issue any DMAs or run flash attention but
+ print debug info. Need to compile with `--xla_tpu_enable_log_recorder`.
+
+ Returns:
+ The output of the attention.
+ """
  q, k, v = queries, keys, values
  static_validate_inputs(
  q,
@@ -1374,7 +1407,7 @@ def ragged_paged_attention_hd64(
  pl.BlockSpec(memory_space=pltpu.HBM),
  pl.BlockSpec(memory_space=pltpu.HBM),
  None if attention_sink is None else pl.BlockSpec(
- memory_space=pltpu.VMEM)
+ memory_space=pltpu.VMEM),
  ]

  out_specs = [
@@ -1438,6 +1471,7 @@ def ragged_paged_attention_hd64(
  _ragged_paged_attention_kernel,
  sm_scale=sm_scale,
  sliding_window=sliding_window,
+ strict_sliding_window=strict_sliding_window,
  soft_cap=soft_cap,
  mask_value=mask_value,
  q_scale=q_scale,
@@ -308,7 +308,13 @@ def sharded_ragged_paged_attention(
  args = (q, k, v, kv_cache, kv_lens, page_indices, cu_q_lens, distribution)

  use_hd64 = q.shape[-1] == 64
- func = ragged_paged_attention_hd64 if use_hd64 else ragged_paged_attention
+
+ func = ragged_paged_attention
+ if use_hd64:
+ func = functools.partial(ragged_paged_attention_hd64,
+ strict_sliding_window=False)
+ else:
+ func = ragged_paged_attention

  if attention_sink is not None:
  if not use_hd64:
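
Note: the head_dim==64 path now binds strict_sliding_window=False to the hd64 kernel via functools.partial, so the extra keyword travels with the function object that the rest of the dispatch logic treats uniformly. A self-contained sketch of that pattern with stand-in attention functions:

import functools

def ragged_paged_attention(*args, **kwargs):  # stand-in for the generic kernel
    return "generic kernel"

def ragged_paged_attention_hd64(*args, strict_sliding_window=True, **kwargs):  # stand-in
    return f"hd64 kernel, strict_sliding_window={strict_sliding_window}"

def pick_kernel(head_dim):
    # Bind the extra keyword only on the head_dim=64 path.
    if head_dim == 64:
        return functools.partial(ragged_paged_attention_hd64,
                                 strict_sliding_window=False)
    return ragged_paged_attention

print(pick_kernel(64)())   # hd64 kernel, strict_sliding_window=False
print(pick_kernel(128)())  # generic kernel
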
@@ -166,9 +166,10 @@ class ShardingConfigManager:
  f"LoRA is not supported with data parallelism "
  f"(DP size: {total_dp_size}). Please disable LoRA or "
  f"set data parallelism to 1.")
+ if sharding_strategy.attention_data_parallelism > 1:
  if not os.environ.get("NEW_MODEL_DESIGN", False):
  raise ValueError(
- "Must run DP with NEW_MODEL_DESIGN enabled. Please set the "
+ "Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set the "
  "NEW_MODEL_DESIGN=True.")

  @property
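
Note: the added condition above scopes the NEW_MODEL_DESIGN requirement to attention data parallelism specifically. A minimal, standalone sketch of that guard pattern (hypothetical helper name):

import os

def check_attention_dp(attention_dp_size: int) -> None:
    # Attention DP is only allowed when the NEW_MODEL_DESIGN env flag is set.
    if attention_dp_size > 1 and not os.environ.get("NEW_MODEL_DESIGN", False):
        raise ValueError(
            "Must run Attention DP with NEW_MODEL_DESIGN enabled. "
            "Please set NEW_MODEL_DESIGN=True.")
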