tpu-inference 0.11.1.dev202511150811-py3-none-any.whl → 0.11.1.dev202511270815-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/lora/test_layers.py +0 -6
- tests/lora/utils.py +0 -8
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +2 -3
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +1 -1
- tpu_inference/executors/ray_distributed_executor.py +27 -11
- tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +141 -107
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +2 -1
- tpu_inference/layers/vllm/fused_moe.py +74 -25
- tpu_inference/layers/vllm/quantization/common.py +6 -1
- tpu_inference/layers/vllm/quantization/mxfp4.py +135 -61
- tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +43 -11
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/weight_utils.py +198 -143
- tpu_inference/models/vllm/vllm_model_wrapper.py +13 -5
- tpu_inference/platforms/tpu_platform.py +15 -2
- tpu_inference/runner/compilation_manager.py +58 -33
- tpu_inference/runner/kv_cache_manager.py +9 -3
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +203 -102
- tpu_inference/spec_decode/jax/eagle3.py +19 -2
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +5 -4
- tpu_inference/worker/tpu_worker.py +160 -23
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/METADATA +3 -2
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/RECORD +43 -48
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/top_level.txt +0 -0
@@ -440,42 +440,54 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
     debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

+    if not wait:
+      # Fetch effective kv from kv cache.
+      def loop_body(i, offset):
+        sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
+        _async_copy(
+            cache_hbm_ref.at[pl.ds(
+                page_indices_ref[page_indices_offset + i] * page_size,
+                sz)],
+            vmem_ref.at[pl.ds(i * page_size, sz)],
+            sem,
+            wait=False,
+        )
+        debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+        return offset + sz
+
+      offset = lax.fori_loop(
+          0,
+          bkv_p_frm_cache,
+          loop_body,
+          0, # offset
+          unroll=False,
       )
-      debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-      return offset + sz
-    offset = lax.fori_loop(
-        0,
-        bkv_p_frm_cache,
-        loop_body,
-        0, # offset
-        unroll=False,
-    )

+      # Fetch kv directly from new kv.
+      @pl.when(bkv_sz_frm_new > 0)
+      def _fetch_bkv_from_new_kv():
+        new_kv_len_start = q_end - kv_left_frm_new
+        debug_print("[RPA debug] new_kv_len_start={}",
+                    new_kv_len_start)
+        debug_print("[RPA debug] offset_in_bkv={}", offset)
+        _async_copy(
+            kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
+            vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
+            sem,
+            wait,
+        )
+
+      return kv_len_start + offset, bkv_sz_frm_new
+    else:
+      offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
+      dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
       _async_copy(
-          sem,
-          wait,
+          src=dst,
+          dst=dst,
+          sem=sem,
+          wait=True,
       )
-    return kv_len_start + offset, bkv_sz_frm_new
+      return kv_len_start + offset, bkv_sz_frm_new

   def _update_kv_cache(seq_idx,
                        bkv_sem_idx,
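The hunk above (and the matching hunks below) splits each kv copy into a start path (wait=False) and a wait path (wait=True). The wheel's own _async_copy helper is not shown in this diff; the following is a minimal sketch of the usual shape of such a helper in Pallas TPU kernels, assumed here only to make the calling convention readable:

    from jax.experimental.pallas import tpu as pltpu

    def _async_copy(src, dst, sem, wait):
        # Bind the copy to a DMA semaphore: starting it issues the transfer,
        # waiting on it blocks until the semaphore has counted the copied bytes.
        copy = pltpu.make_async_copy(src, dst, sem)
        if wait:
            copy.wait()
        else:
            copy.start()

Under that reading, the wait branches above can pass src=dst=dst: when waiting, only the semaphore and the size of the referenced slice matter, not which buffers were named when the copy was started.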
@@ -511,30 +523,41 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] p_ignore={}", p_ignore)
     debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

+    if not wait:
+
+      def loop_body(i, states):
+        update_sz, ignore = states
+        sz = jnp.minimum(page_size - ignore, update_sz)
+
+        _async_copy(
+            vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
+                              sz)],
+            cache_hbm_ref.at[pl.ds(
+                page_indices_ref[page_indices_offset + i] * page_size +
+                ignore,
+                sz,
+            )],
+            sem,
+            wait=False,
+        )
+        debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+        return update_sz - sz, 0
+
+      lax.fori_loop(
+          0,
+          kv_p_end - kv_p_start,
+          loop_body,
+          (update_sz, ignore), # total transfer size
+          unroll=False,
+      )
+    else:
+      dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
       _async_copy(
-          sz,
-          )],
-          sem,
-          wait,
+          src=dst,
+          dst=dst,
+          sem=sem,
+          wait=True,
       )
-      debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-      return update_sz - sz, 0
-    lax.fori_loop(
-        0,
-        kv_p_end - kv_p_start,
-        loop_body,
-        (update_sz, ignore), # total transfer size
-        unroll=False,
-    )

   def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
     sem = sems.at[1, bq_sem_idx]
@@ -267,6 +267,7 @@ def _ragged_paged_attention_kernel(
     *,
     sm_scale: float,
     sliding_window: int | None = None,
+    strict_sliding_window: bool = True,
     soft_cap: float | None = None,
     mask_value: float = DEFAULT_MASK_VALUE,
     q_scale: float | None = None,
@@ -317,19 +318,20 @@ def _ragged_paged_attention_kernel(
   q_len = q_end - q_start
   kv_len = kv_lens_ref[seq_idx]

-  bkv_idx_start = 0 if sliding_window is None else jnp.maximum(
-      kv_len - sliding_window, 0) // bkv_sz
   if sliding_window is None:
+    bkv_idx_start = next_seq_bkv_idx_start = 0
   else:
+    bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window,
+                                0) // bkv_sz

     def get_next_bkv_idx_start():
       next_kv_len = kv_lens_ref[seq_idx + 1]
+      next_q_len = cu_q_lens_ref[seq_idx + 2] - q_end
+      return jnp.maximum(next_kv_len - next_q_len - sliding_window,
+                         0) // bkv_sz

+    next_seq_bkv_idx_start = lax.cond(seq_idx + 1 < num_seqs,
+                                      get_next_bkv_idx_start, lambda: 0)

   def debug_print(msg, *args):
     if debug_mode:
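The new start-block formula above can be sanity-checked with plain arithmetic; a toy example with invented numbers (only the variable names come from the diff):

    # A long cached prefix, a small query chunk, a 128-token sliding window,
    # and 256-token kv blocks.
    kv_len, q_len, sliding_window, bkv_sz = 1000, 8, 128, 256

    bkv_idx_start = max(kv_len - q_len - sliding_window, 0) // bkv_sz
    print(bkv_idx_start)  # 3 -> kv blocks 0..2 lie entirely outside the window

Compared with the removed formula (kv_len - sliding_window), subtracting q_len anchors the window at the first query token of the chunk rather than at the end of the kv, so earlier query tokens in a multi-token chunk keep all the kv blocks they can still see.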
@@ -386,16 +388,19 @@ def _ragged_paged_attention_kernel(
       s *= k_scale
     if q_scale is not None:
       s *= q_scale
+    if soft_cap is not None:
+      s = soft_cap * jnp.tanh(s / soft_cap)

     q_span = (kv_len - q_len + bq_idx * bq_sz +
               lax.broadcasted_iota(jnp.int32, s.shape, 0) //
               num_q_heads_per_kv_head)
     k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
-    mask =
+    mask = k_span <= q_span

-    if
+    if sliding_window is not None and strict_sliding_window:
+      mask = jnp.logical_and(mask, q_span - sliding_window < k_span)
+
+    s = jnp.where(mask, s, mask_value)
     s_rowmax = jnp.max(s, axis=1, keepdims=True)

     if attention_sink_ref is not None:
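The masking math above is easy to check outside the kernel. A simplified sketch in plain JAX (toy shapes; the per-kv-head query grouping via num_q_heads_per_kv_head is dropped here):

    import jax.numpy as jnp
    from jax import lax

    def block_mask(kv_len, q_len, bq_idx, bkv_idx, bq_sz, bkv_sz,
                   sliding_window=None, strict_sliding_window=True):
        shape = (bq_sz, bkv_sz)
        # Absolute kv position of each query row and each key column in this block.
        q_span = (kv_len - q_len + bq_idx * bq_sz +
                  lax.broadcasted_iota(jnp.int32, shape, 0))
        k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, shape, 1)
        mask = k_span <= q_span  # causal
        if sliding_window is not None and strict_sliding_window:
            # Drop keys that fall out of the window ending at the query position.
            mask = jnp.logical_and(mask, q_span - sliding_window < k_span)
        return mask

    print(block_mask(kv_len=8, q_len=4, bq_idx=0, bkv_idx=0,
                     bq_sz=4, bkv_sz=8, sliding_window=3))

With strict_sliding_window=False only the causal term is applied inside the kernel; the sliding window then appears to influence only which kv blocks are visited at all (the bkv_idx_start computation earlier in this diff).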
@@ -475,42 +480,54 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
     debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

+    if not wait:
+      # Fetch effective kv from kv cache.
+      def loop_body(i, offset):
+        sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
+        _async_copy(
+            cache_hbm_ref.at[pl.ds(
+                page_indices_ref[page_indices_offset + i] * page_size,
+                sz)],
+            vmem_ref.at[pl.ds(i * page_size, sz)],
+            sem,
+            wait=False,
+        )
+        debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+        return offset + sz
+
+      offset = lax.fori_loop(
+          0,
+          bkv_p_frm_cache,
+          loop_body,
+          0, # offset
+          unroll=False,
      )
-      debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-      return offset + sz
-    offset = lax.fori_loop(
-        0,
-        bkv_p_frm_cache,
-        loop_body,
-        0, # offset
-        unroll=False,
-    )

+      # Fetch kv directly from new kv.
+      @pl.when(bkv_sz_frm_new > 0)
+      def _fetch_bkv_from_new_kv():
+        new_kv_len_start = q_end - kv_left_frm_new
+        debug_print("[RPA debug] new_kv_len_start={}",
+                    new_kv_len_start)
+        debug_print("[RPA debug] offset_in_bkv={}", offset)
+        _async_copy(
+            kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
+            vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
+            sem,
+            wait,
+        )
+
+      return kv_len_start + offset, bkv_sz_frm_new
+    else:
+      offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
+      dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
       _async_copy(
-          sem,
-          wait,
+          src=dst,
+          dst=dst,
+          sem=sem,
+          wait=True,
       )
-    return kv_len_start + offset, bkv_sz_frm_new
+      return kv_len_start + offset, bkv_sz_frm_new

   def _update_kv_cache(seq_idx,
                        bkv_sem_idx,
@@ -546,30 +563,41 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] p_ignore={}", p_ignore)
     debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

+    if not wait:
+
+      def loop_body(i, states):
+        update_sz, ignore = states
+        sz = jnp.minimum(page_size - ignore, update_sz)
+
+        _async_copy(
+            vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
+                              sz)],
+            cache_hbm_ref.at[pl.ds(
+                page_indices_ref[page_indices_offset + i] * page_size +
+                ignore,
+                sz,
+            )],
+            sem,
+            wait=False,
+        )
+        debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+        return update_sz - sz, 0
+
+      lax.fori_loop(
+          0,
+          kv_p_end - kv_p_start,
+          loop_body,
+          (update_sz, ignore), # total transfer size
+          unroll=False,
+      )
+    else:
+      dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
       _async_copy(
-          sz,
-          )],
-          sem,
-          wait,
+          src=dst,
+          dst=dst,
+          sem=sem,
+          wait=True,
       )
-      debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-      return update_sz - sz, 0
-    lax.fori_loop(
-        0,
-        kv_p_end - kv_p_start,
-        loop_body,
-        (update_sz, ignore), # total transfer size
-        unroll=False,
-    )

   def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
     sem = sems.at[1, bq_sem_idx]
@@ -737,13 +765,17 @@ def _ragged_paged_attention_kernel(
     next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
     next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)

+    if sliding_window is None:
+      next_bkv_start_idx = 0
+    else:
+      next_bkv_start_idx = lax.select(
           is_last_bq,
+          next_seq_bkv_idx_start,
           bkv_idx_start,
-    )
+      )
+      next_bkv_idx = lax.select(is_last_bkv, next_bkv_start_idx,
+                                next_bkv_idx)
+
     return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx

   def compute_with_bq(bq_idx, _):
@@ -1226,6 +1258,7 @@ def static_validate_inputs(
     static_argnames=(
         "sm_scale",
         "sliding_window",
+        "strict_sliding_window",
        "soft_cap",
         "mask_value",
         "q_scale",
@@ -1255,6 +1288,7 @@ def ragged_paged_attention_hd64(
     *,
     sm_scale: float = 1.0,
     sliding_window: int | None = None,
+    strict_sliding_window: bool = True,
     soft_cap: float | None = None,
     mask_value: float | None = DEFAULT_MASK_VALUE,
     q_scale: float | None = None,
@@ -1269,42 +1303,41 @@ def ragged_paged_attention_hd64(
     # Debug params.
     debug_mode: bool = False,
 ):
-  """A
-  """
+  """A variant of ragged paged attention for head_dim=64.
+
+  Args:
+    queries: concatenated all sequences' queries.
+    keys: concatenated all sequences' keys (quantized).
+    values: concatenated all sequences' values (quantized).
+    kv_cache: paged KV cache with TPU-friendly shape.
+    kv_lens: padded kv lengths. Only the first num_seqs values are valid.
+    page_indices: flattened page indices look-up table by (seq_id, page_id).
+    cu_q_lens: the cumulative sum of the effective query lengths. Similar to
+      kv_lens, only the first num_seqs+1 values are valid.
+    distribution: (i, j, k) represents that sequences[0:i] are decode-only,
+      sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
+      k is also the total number of sequences.
+    attention_sink: optional attention sink for each q head.
+    sm_scale: the softmax scale which will be applied to the Q@K^T.
+    sliding_window: the sliding window size for the attention.
+    strict_sliding_window: compute tokens that are strictly within the window.
+    soft_cap: the logit soft cap for the attention.
+    mask_value: mask value for causal mask.
+    q_scale: the scale for the query.
+    k_scale: the scale for the key cache.
+    v_scale: the scale for the value cache.
+    chunk_prefill_size: the chunk prefill size for the attention.
+    num_kv_pages_per_block: number of kv pages to be processed in one flash
+      attention block in the pallas kernel.
+    num_queries_per_block: number of kv pages to be processed in one flash
+      attention block in the pallas kernel.
+    vmem_limit_bytes: the vmem limit for the pallas kernel.
+    debug_mode: if true, RPA does not issue any DMAs or run flash attention but
+      print debug info. Need to compile with `--xla_tpu_enable_log_recorder`.
+
+  Returns:
+    The output of the attention.
+  """
   q, k, v = queries, keys, values
   static_validate_inputs(
       q,
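The Args section above pins down how the host-side metadata is laid out; a toy illustration of mutually consistent inputs (the values and padding lengths are invented for the example):

    import jax.numpy as jnp

    # Three sequences: seq 0 is decode-only (1 new token), seq 1 is
    # chunked-prefill-only (all 16 tokens are new), seq 2 is mixed.
    new_q_lens = [1, 16, 8]  # per-sequence new query tokens (explanatory only)
    kv_lens = jnp.array([33, 16, 40, 0, 0], dtype=jnp.int32)   # padded; first num_seqs entries valid
    cu_q_lens = jnp.array([0, 1, 17, 25, 0], dtype=jnp.int32)  # prefix sums; first num_seqs + 1 entries valid
    distribution = (1, 2, 3)  # [0:1] decode-only, [1:2] chunked-prefill-only, [2:3] mixed; 3 == num_seqs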
@@ -1374,7 +1407,7 @@ def ragged_paged_attention_hd64(
       pl.BlockSpec(memory_space=pltpu.HBM),
       pl.BlockSpec(memory_space=pltpu.HBM),
       None if attention_sink is None else pl.BlockSpec(
-          memory_space=pltpu.VMEM)
+          memory_space=pltpu.VMEM),
   ]

   out_specs = [
@@ -1438,6 +1471,7 @@ def ragged_paged_attention_hd64(
       _ragged_paged_attention_kernel,
       sm_scale=sm_scale,
       sliding_window=sliding_window,
+      strict_sliding_window=strict_sliding_window,
       soft_cap=soft_cap,
       mask_value=mask_value,
       q_scale=q_scale,
@@ -308,7 +308,13 @@ def sharded_ragged_paged_attention(
     args = (q, k, v, kv_cache, kv_lens, page_indices, cu_q_lens, distribution)

     use_hd64 = q.shape[-1] == 64
+
+    func = ragged_paged_attention
+    if use_hd64:
+        func = functools.partial(ragged_paged_attention_hd64,
+                                 strict_sliding_window=False)
+    else:
+        func = ragged_paged_attention

     if attention_sink is not None:
         if not use_hd64:
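Condensing the dispatch above into a free-standing sketch (the function names are the ones in the diff; the wrapper itself is illustrative, not part of the package):

    import functools

    def pick_attention_fn(head_dim, ragged_paged_attention, ragged_paged_attention_hd64):
        # head_dim == 64 routes to the hd64 kernel variant; this call site also
        # turns off the stricter in-kernel sliding-window mask.
        if head_dim == 64:
            return functools.partial(ragged_paged_attention_hd64,
                                     strict_sliding_window=False)
        return ragged_paged_attention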
@@ -166,9 +166,10 @@ class ShardingConfigManager:
                 f"LoRA is not supported with data parallelism "
                 f"(DP size: {total_dp_size}). Please disable LoRA or "
                 f"set data parallelism to 1.")
+        if sharding_strategy.attention_data_parallelism > 1:
             if not os.environ.get("NEW_MODEL_DESIGN", False):
                 raise ValueError(
-                    "Must run DP with NEW_MODEL_DESIGN enabled. Please set the "
+                    "Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set the "
                     "NEW_MODEL_DESIGN=True.")

     @property