tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511130813__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +34 -303
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
- tests/lora/test_layers.py +6 -0
- tests/lora/utils.py +8 -0
- tests/test_utils.py +16 -24
- tpu_inference/__init__.py +3 -22
- tpu_inference/core/core_tpu.py +9 -17
- tpu_inference/core/disagg_utils.py +8 -6
- tpu_inference/distributed/tpu_connector.py +4 -3
- tpu_inference/distributed/utils.py +2 -3
- tpu_inference/envs.py +8 -61
- tpu_inference/executors/ray_distributed_executor.py +11 -31
- tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +143 -287
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -7
- tpu_inference/layers/jax/attention/attention.py +1 -1
- tpu_inference/layers/{common → jax}/attention_interface.py +2 -8
- tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
- tpu_inference/layers/jax/sample/sampling.py +2 -2
- tpu_inference/layers/{common → jax}/sharding.py +5 -5
- tpu_inference/layers/vllm/attention.py +1 -1
- tpu_inference/layers/vllm/fused_moe.py +208 -170
- tpu_inference/layers/vllm/quantization/__init__.py +3 -7
- tpu_inference/layers/vllm/quantization/awq.py +3 -4
- tpu_inference/layers/vllm/quantization/common.py +1 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +2 -4
- tpu_inference/layers/vllm/quantization/unquantized.py +67 -62
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +2 -1
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/common/model_loader.py +12 -46
- tpu_inference/models/jax/llama3.py +3 -4
- tpu_inference/models/jax/llama_eagle3.py +5 -8
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +2 -3
- tpu_inference/models/jax/qwen2_5_vl.py +50 -165
- tpu_inference/models/jax/qwen3.py +2 -3
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
- tpu_inference/models/jax/utils/weight_utils.py +143 -198
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -32
- tpu_inference/platforms/tpu_platform.py +34 -47
- tpu_inference/runner/compilation_manager.py +60 -145
- tpu_inference/runner/kv_cache.py +2 -2
- tpu_inference/runner/kv_cache_manager.py +18 -17
- tpu_inference/runner/persistent_batch_manager.py +2 -40
- tpu_inference/runner/structured_decoding_manager.py +3 -2
- tpu_inference/runner/tpu_runner.py +135 -283
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +21 -71
- tpu_inference/tpu_info.py +3 -4
- tpu_inference/utils.py +15 -38
- tpu_inference/worker/tpu_worker.py +26 -163
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/METADATA +3 -4
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/RECORD +63 -61
- tests/test_envs.py +0 -203
- tpu_inference/layers/common/quant_methods.py +0 -8
- tpu_inference/layers/vllm/quantization/mxfp4.py +0 -331
- tpu_inference/models/jax/llama_guard_4.py +0 -361
- /tpu_inference/layers/{common → jax}/binary_search.py +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/WHEEL +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/top_level.txt +0 -0
@@ -267,7 +267,6 @@ def _ragged_paged_attention_kernel(
     *,
     sm_scale: float,
     sliding_window: int | None = None,
-    strict_sliding_window: bool = True,
     soft_cap: float | None = None,
     mask_value: float = DEFAULT_MASK_VALUE,
     q_scale: float | None = None,
@@ -318,21 +317,6 @@ def _ragged_paged_attention_kernel(
   q_len = q_end - q_start
   kv_len = kv_lens_ref[seq_idx]
 
-  if sliding_window is None:
-    bkv_idx_start = next_seq_bkv_idx_start = 0
-  else:
-    bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window,
-                                0) // bkv_sz
-
-    def get_next_bkv_idx_start():
-      next_kv_len = kv_lens_ref[seq_idx + 1]
-      next_q_len = cu_q_lens_ref[seq_idx + 2] - q_end
-      return jnp.maximum(next_kv_len - next_q_len - sliding_window,
-                         0) // bkv_sz
-
-    next_seq_bkv_idx_start = lax.cond(seq_idx + 1 < num_seqs,
-                                      get_next_bkv_idx_start, lambda: 0)
-
   def debug_print(msg, *args):
     if debug_mode:
       pl.debug_print(msg, *args)
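
For context, the block deleted above computed the first KV block index that can still fall inside the sliding window, so earlier blocks could be skipped per sequence; the new code relies on masking alone instead. A standalone sketch of that arithmetic with hypothetical sizes (plain Python, not kernel code):

```python
# Sketch of the removed bkv_idx_start arithmetic (hypothetical sizes, not kernel code).
# With a sliding window, keys older than (kv_len - q_len - sliding_window) can never
# be attended to, so whole KV blocks before that point used to be skipped.
kv_len, q_len, sliding_window, bkv_sz = 4096, 128, 1024, 256

first_useful_token = max(kv_len - q_len - sliding_window, 0)   # 2944
bkv_idx_start = first_useful_token // bkv_sz
print(bkv_idx_start)  # 11 -> the old kernel started its KV-block loop here instead of 0
```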
@@ -352,7 +336,7 @@ def _ragged_paged_attention_kernel(
   debug_print("[RPA debug] q_len={}", q_len)
   debug_print("[RPA debug] kv_len={}", kv_len)
 
-  def
+  def flash_attention(
       q,  # [actual_bq_sz * num_q_heads_per_kv_head, actual_head_dim_x2]
       kv,  # [bkv_sz, actual_head_dim_x2]
       *,
@@ -366,10 +350,11 @@ def _ragged_paged_attention_kernel(
     assert kv.shape == (bkv_sz, actual_head_dim_x2)
     head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
     head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
+    head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]
 
     def load_with_init(ref, init_val):
-      return jnp.where(bkv_idx ==
-
+      return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val),
+                       ref[...])
 
     # Follow FlashAttention-2 forward pass.
     if q_scale is not None:
@@ -387,27 +372,26 @@ def _ragged_paged_attention_kernel(
       s *= k_scale
     if q_scale is not None:
       s *= q_scale
-    if soft_cap is not None:
-      s = soft_cap * jnp.tanh(s / soft_cap)
 
     q_span = (kv_len - q_len + bq_idx * bq_sz +
               lax.broadcasted_iota(jnp.int32, s.shape, 0) //
               num_q_heads_per_kv_head)
     k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
-    mask =
-
-    if sliding_window is not None
-      mask = jnp.
+    mask = q_span < k_span
+    # TODO(jevinjiang, xiowei): reduce pages_per_seq based on sliding_window.
+    if sliding_window is not None:
+      mask = jnp.logical_or(mask, q_span - sliding_window >= k_span)
 
-
+    if soft_cap is not None:
+      s = soft_cap * jnp.tanh(s / soft_cap)
+    s += jnp.where(mask, mask_value, 0.0)
     s_rowmax = jnp.max(s, axis=1, keepdims=True)
 
     if attention_sink_ref is not None:
       sinks = attention_sink_ref[kv_head_idx]
       actual_bq_sz = q.shape[0] // num_q_heads_per_kv_head
       m_prev_init = jnp.concat([sinks] * actual_bq_sz, axis=0)
-      m_prev = jnp.where(bkv_idx ==
-                         head_m_ref[...])
+      m_prev = jnp.where(bkv_idx == 0, m_prev_init, head_m_ref[...])
     else:
       m_prev = load_with_init(head_m_ref, -jnp.inf)
 
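
The rewritten masking above builds a single boolean mask combining the causal condition and the sliding-window condition, applies the soft cap to the raw scores first, and only then adds mask_value. A toy sketch of the mask construction in plain JAX (hypothetical positions and sizes, not the kernel's Pallas code):

```python
# True entries are masked out; the window test uses ">=" so exactly
# `sliding_window` past keys remain visible to each query token.
import jax.numpy as jnp
from jax import lax

q_positions = 10 + lax.broadcasted_iota(jnp.int32, (4, 16), 0)  # absolute query positions
k_positions = lax.broadcasted_iota(jnp.int32, (4, 16), 1)       # absolute key positions
sliding_window = 8

mask = q_positions < k_positions                                 # causal: mask future keys
mask = jnp.logical_or(mask, q_positions - sliding_window >= k_positions)  # mask too-old keys
print(mask.astype(jnp.int32))  # 1 marks a position whose score gets mask_value added
```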
@@ -415,33 +399,15 @@ def _ragged_paged_attention_kernel(
     head_m_ref[...] = m_curr
     p = jnp.exp(s - broadcast_minor(m_curr, s.shape))
 
+    pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
+    if v_scale is not None:
+      pv *= v_scale
+
     p_rowsum = jnp.sum(p, axis=1, keepdims=True)
     exp_m_diff = jnp.exp(m_prev - m_curr)
     l_prev = load_with_init(head_l_ref, 1.0)
     l_curr = exp_m_diff * l_prev + p_rowsum
     head_l_ref[...] = l_curr
-
-    return p, exp_m_diff
-
-  def flash_attention_step2_pv(
-      q_shape_0,
-      kv,  # [bkv_sz, actual_head_dim_x2]
-      p,  # from step1
-      exp_m_diff,  # from step1
-      *,
-      bkv_idx,
-      kv_head_idx,
-  ):
-    head_acc_ref = acc_ref.at[kv_head_idx, :q_shape_0]
-
-    def load_with_init(ref, init_val):
-      return jnp.where(bkv_idx == bkv_idx_start,
-                       jnp.full_like(ref, init_val), ref[...])
-
-    pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
-    if v_scale is not None:
-      pv *= v_scale
-
     o_prev = load_with_init(head_acc_ref, 0.0)
     o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
     head_acc_ref[...] = o_curr
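
With this change, the former two-step split (scores in one helper, P@V in flash_attention_step2_pv) is folded into a single flash_attention that keeps the FlashAttention-2 running statistics (m, l, acc) in refs. A minimal sketch of that online-softmax recurrence as a plain JAX loop over KV blocks (toy shapes, no paging, no scaling):

```python
import jax.numpy as jnp

def online_attention(q, k_blocks, v_blocks):
    m = jnp.full((q.shape[0], 1), -jnp.inf)               # running row max of scores
    l = jnp.zeros((q.shape[0], 1))                        # running softmax denominator
    acc = jnp.zeros((q.shape[0], v_blocks[0].shape[-1]))  # running unnormalized output
    for k_blk, v_blk in zip(k_blocks, v_blocks):
        s = q @ k_blk.T
        m_curr = jnp.maximum(m, s.max(axis=1, keepdims=True))
        p = jnp.exp(s - m_curr)
        exp_m_diff = jnp.exp(m - m_curr)
        l = exp_m_diff * l + p.sum(axis=1, keepdims=True)
        acc = exp_m_diff * acc + p @ v_blk                # P @ V folded into the same step
        m = m_curr
    return acc / l                                        # final normalization

q = jnp.ones((4, 64))
k_blocks = [jnp.ones((16, 64)) * 0.01, jnp.ones((16, 64)) * 0.02]
v_blocks = [jnp.ones((16, 8)), jnp.ones((16, 8)) * 2.0]
print(online_attention(q, k_blocks, v_blocks).shape)  # (4, 8)
```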
@@ -456,12 +422,7 @@ def _ragged_paged_attention_kernel(
     else:
       cp.start()
 
-  def _fetch_bkv(seq_idx,
-                 bkv_idx,
-                 bkv_sem_idx,
-                 *,
-                 is_full_fetch=False,
-                 wait=False):
+  def _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, wait=False):
     sem = sems.at[0, bkv_sem_idx]
     vmem_ref = bkv_x2_ref.at[bkv_sem_idx]
 
@@ -502,73 +463,42 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
     debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
 
-
-
-
-
-
-
-
-
-
-
-              wait=False,
-          )
-          debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-          return offset + sz
-
-        offset = lax.fori_loop(
-            0,
-            bkv_p_frm_cache,
-            loop_body,
-            0,  # offset
-            unroll=False,
+    # Fetch effective kv from kv cache.
+    def loop_body(i, offset):
+      sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
+      _async_copy(
+          cache_hbm_ref.at[pl.ds(
+              page_indices_ref[page_indices_offset + i] * page_size,
+              sz)],
+          vmem_ref.at[pl.ds(i * page_size, sz)],
+          sem,
+          wait,
       )
+      debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+      return offset + sz
+
+    offset = lax.fori_loop(
+        0,
+        bkv_p_frm_cache,
+        loop_body,
+        0,  # offset
+        unroll=False,
+    )
 
-
-
-
-
-
-
-        debug_print("[RPA debug] offset_in_bkv={}", offset)
-        _async_copy(
-            kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
-            vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
-            sem,
-            wait,
-        )
-
-        # NOTE(chengjiyao): This condition is true for the first two bkv fetches.
-        # We need to ensure the bkv_x2_ref VMEM buffer is fully initialized to
-        # avoid potential NaN values in regions not overwritten by actual data.
-        # This is done by padding the remaining parts of the buffer with data
-        # from the KV cache. This special handling is only strictly necessary
-        # until both buffers in the double buffer (bkv_x2_ref) have been written
-        # to at least once.
-        @pl.when(is_full_fetch)
-        def _make_sure_bkv_vmem_is_not_nan():
-          effective_sz = offset + bkv_sz_frm_new
-          remaining_sz = bkv_sz - effective_sz
-          _async_copy(
-              cache_hbm_ref.at[pl.ds(0, remaining_sz)],
-              vmem_ref.at[pl.ds(effective_sz, remaining_sz)],
-              sem,
-              wait,
-          )
-
-        return kv_len_start + offset, bkv_sz_frm_new
-    else:
-      offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
-      sz = lax.select(is_full_fetch, bkv_sz, offset + bkv_sz_frm_new)
-      dst = vmem_ref.at[pl.ds(0, sz)]
+    # Fetch kv directly from new kv.
+    @pl.when(bkv_sz_frm_new > 0)
+    def _fetch_bkv_from_new_kv():
+      new_kv_len_start = q_end - kv_left_frm_new
+      debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
+      debug_print("[RPA debug] offset_in_bkv={}", offset)
       _async_copy(
-
-
-          sem
-          wait
+          kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
+          vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
+          sem,
+          wait,
       )
-
+
+    return kv_len_start + offset, bkv_sz_frm_new
 
   def _update_kv_cache(seq_idx,
                        bkv_sem_idx,
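
The rewritten _fetch_bkv above gathers the cached part of a KV block page by page through page_indices and appends the freshly produced KV at the end, dropping the old is_full_fetch padding path. A logical sketch of what one assembled block looks like, using plain JAX and hypothetical sizes (no DMAs, no Pallas):

```python
import jax.numpy as jnp

page_size, head_dim = 4, 8
cache = jnp.arange(16 * page_size * head_dim, dtype=jnp.float32).reshape(16 * page_size, head_dim)
new_kv = -jnp.ones((6, head_dim))        # KV produced by the current step ("new kv")
page_indices = [7, 2, 11]                # pages holding this sequence's cached KV
kv_left_frm_cache = 10                   # cached tokens that belong to this block

chunks = []
for i, page in enumerate(page_indices):
    sz = min(page_size, kv_left_frm_cache - i * page_size)
    if sz <= 0:
        break
    chunks.append(cache[page * page_size:page * page_size + sz])
block = jnp.concatenate(chunks + [new_kv], axis=0)
print(block.shape)  # (16, 8): 10 cached tokens followed by 6 new ones
```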
@@ -604,41 +534,30 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] p_ignore={}", p_ignore)
     debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
 
-
-
-
-
-        sz = jnp.minimum(page_size - ignore, update_sz)
-
-        _async_copy(
-            vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
-                              sz)],
-            cache_hbm_ref.at[pl.ds(
-                page_indices_ref[page_indices_offset + i] * page_size +
-                ignore,
-                sz,
-            )],
-            sem,
-            wait=False,
-        )
-        debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-        return update_sz - sz, 0
-
-      lax.fori_loop(
-          0,
-          kv_p_end - kv_p_start,
-          loop_body,
-          (update_sz, ignore),  # total transfer size
-          unroll=False,
-      )
-    else:
-      dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
+    def loop_body(i, states):
+      update_sz, ignore = states
+      sz = jnp.minimum(page_size - ignore, update_sz)
+
       _async_copy(
-
-
-
-
+          vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore, sz)],
+          cache_hbm_ref.at[pl.ds(
+              page_indices_ref[page_indices_offset + i] * page_size +
+              ignore,
+              sz,
+          )],
+          sem,
+          wait,
       )
+      debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+      return update_sz - sz, 0
+
+    lax.fori_loop(
+        0,
+        kv_p_end - kv_p_start,
+        loop_body,
+        (update_sz, ignore),  # total transfer size
+        unroll=False,
+    )
 
   def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
     sem = sems.at[1, bq_sem_idx]
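
In the restructured _update_kv_cache, loop_body now threads the remaining (update_sz, ignore) pair through lax.fori_loop: only the first page honors the ignore offset, and update_sz shrinks by the amount written each iteration. A self-contained sketch of that carry with hypothetical sizes (it only computes the per-page copy sizes, no copies):

```python
import jax.numpy as jnp
from jax import lax

page_size = 8

def per_page_copy_sizes(update_sz, ignore, num_pages):
    """Returns how many tokens the loop would write into each page."""
    def loop_body(i, state):
        update_sz, ignore, sizes = state
        sz = jnp.minimum(page_size - ignore, update_sz)
        sizes = sizes.at[i].set(sz)
        # After the first page, nothing is ignored and the remaining size shrinks.
        return update_sz - sz, jnp.int32(0), sizes
    sizes = jnp.zeros(num_pages, dtype=jnp.int32)
    state = (jnp.int32(update_sz), jnp.int32(ignore), sizes)
    return lax.fori_loop(0, num_pages, loop_body, state, unroll=False)[2]

print(per_page_copy_sizes(13, 5, 3))  # [3 8 2]
```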
@@ -688,18 +607,11 @@ def _ragged_paged_attention_kernel(
         wait,
     )
 
-  def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx
-    return _fetch_bkv(seq_idx,
-                      bkv_idx,
-                      bkv_sem_idx,
-                      is_full_fetch=is_full_fetch)
+  def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
+    return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)
 
-  def wait_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx
-    return _fetch_bkv(seq_idx,
-                      bkv_idx,
-                      bkv_sem_idx,
-                      is_full_fetch=is_full_fetch,
-                      wait=True)
+  def wait_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
+    return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, wait=True)
 
   def start_fetch_bq(seq_idx, bq_idx, bq_sem_idx):
     return _fetch_bq(seq_idx, bq_idx, bq_sem_idx)
@@ -757,7 +669,7 @@ def _ragged_paged_attention_kernel(
     vec = ref[start::step]
     return vec
 
-  def strided_load_bkv(bkv_sem_idx, start, step):
+  def strided_load_bkv(bkv_sem_idx, start, step, *, bkv_mask):
     assert start % kv_packing == 0
     assert step % kv_packing == 0
     start //= kv_packing
@@ -766,6 +678,7 @@ def _ragged_paged_attention_kernel(
                                bkv_sz * step, actual_head_dim_x2))
 
     kv = strided_load(kv_ref, start, step)
+    kv = lax.select(bkv_mask, kv, jnp.zeros_like(kv))
    bitwidth = 32 // kv_packing
     repack_ty = jnp.dtype(f"uint{bitwidth}")
     lst = []
@@ -806,23 +719,12 @@ def _ragged_paged_attention_kernel(
   def get_next_bkv_ids(seq_idx, bq_idx, bkv_idx, bkv_sem_idx):
     next_bkv_idx = bkv_idx + 1
     is_last_bkv = next_bkv_idx == num_bkv
+    next_bkv_idx = lax.select(is_last_bkv, 0, next_bkv_idx)
     next_bq_idx = lax.select(is_last_bkv, bq_idx + 1, bq_idx)
     is_last_bq = next_bq_idx == num_bq
     next_bq_idx = lax.select(is_last_bq, 0, next_bq_idx)
     next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
     next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)
-
-    if sliding_window is None:
-      next_bkv_start_idx = 0
-    else:
-      next_bkv_start_idx = lax.select(
-          is_last_bq,
-          next_seq_bkv_idx_start,
-          bkv_idx_start,
-      )
-    next_bkv_idx = lax.select(is_last_bkv, next_bkv_start_idx,
-                              next_bkv_idx)
-
     return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx
 
   def compute_with_bq(bq_idx, _):
@@ -839,36 +741,31 @@ def _ragged_paged_attention_kernel(
     def compute_with_bkv(bkv_idx, _):
       # Create bitmask for KV.
      assert bkv_sz % kv_packing == 0
+      actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
+      bkv_shape = (bkv_sz, actual_head_dim_x2)
+      bkv_mask = lax.broadcasted_iota(jnp.int32, bkv_shape,
+                                      0) < actual_bkv_sz
 
       # Get next bkv ids.
       bkv_sem_idx = sem_ids_ref[1]
-      next_seq_idx,
-
+      next_seq_idx, _, next_bkv_idx, next_bkv_sem_idx = get_next_bkv_ids(
+          seq_idx, bq_idx, bkv_idx, bkv_sem_idx)
 
       # Prefetch next bkv
       @pl.when(next_seq_idx < num_seqs)
       def prefetch_next_bkv():
         sem_ids_ref[1] = next_bkv_sem_idx
-        start_fetch_bkv(
-
-            next_bkv_idx,
-            next_bkv_sem_idx,
-            is_full_fetch=next_seq_idx + next_bq_idx_for_kv +
-            next_bkv_idx < 2,
-        )
+        start_fetch_bkv(next_seq_idx, next_bkv_idx,
+                        next_bkv_sem_idx)
 
       # Wait for cur bq if not ready yet
-      @pl.when(bkv_idx ==
+      @pl.when(bkv_idx == 0)
       def wait_cur_bq():
         wait_fetch_bq(seq_idx, bq_idx, bq_sem_idx)
 
       # Wait for cur bkv
-      offset, update_sz = wait_fetch_bkv(
-
-          bkv_idx,
-          bkv_sem_idx,
-          is_full_fetch=seq_idx + bq_idx + bkv_idx < 2,
-      )
+      offset, update_sz = wait_fetch_bkv(seq_idx, bkv_idx,
+                                         bkv_sem_idx)
 
       # Start updating bkv to kv cache if applicable.
       # Only needed in first bq loop.
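
The bkv_mask introduced above marks which rows of the current KV block lie within the valid kv_len; strided_load_bkv then zeroes the out-of-range rows with lax.select (see the strided_load_bkv change earlier), replacing the removed is_full_fetch padding path. A toy sketch with hypothetical sizes, plain JAX:

```python
import jax.numpy as jnp
from jax import lax

bkv_sz, head_dim_x2 = 8, 4
kv_len, bkv_idx = 21, 2                  # this block only has 21 - 16 = 5 valid rows

actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
bkv_mask = lax.broadcasted_iota(jnp.int32, (bkv_sz, head_dim_x2), 0) < actual_bkv_sz
kv = jnp.ones((bkv_sz, head_dim_x2))     # stand-in for the loaded KV block
kv = lax.select(bkv_mask, kv, jnp.zeros_like(kv))
print(kv[:, 0])  # [1. 1. 1. 1. 1. 0. 0. 0.]
```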
@@ -887,70 +784,31 @@ def _ragged_paged_attention_kernel(
         return
 
       # Flash attention with cur bkv and bq
-      prev_bq_shape_0 = None
-      prev_kv_head_bkv = None
-      prev_kv_head_idx = None
-      prev_kv_head_p = None
-      prev_kv_head_exp_m_diff = None
      for kv_head_start in range(0, actual_num_kv_heads, kv_packing):
        bkv_lst = strided_load_bkv(
            bkv_sem_idx,
            kv_head_start,
            num_kv_heads,
+            bkv_mask=bkv_mask,
        )
        assert len(bkv_lst) == kv_packing
        for i in range(kv_packing):
-
-          if
+          kv_head_idx = kv_head_start + i
+          if kv_head_idx >= actual_num_kv_heads:
            break
-
-
-
-
-
-
-
-
-
-
-
-
-
-              bq_idx=bq_idx,
-              bkv_idx=bkv_idx,
-              kv_head_idx=cur_kv_head_idx,
-          ))
-          if prev_bq_shape_0 is not None:
-            flash_attention_step2_pv(
-                prev_bq_shape_0,
-                prev_kv_head_bkv,
-                prev_kv_head_p,
-                prev_kv_head_exp_m_diff,
-                bkv_idx=bkv_idx,
-                kv_head_idx=prev_kv_head_idx,
-            )
-          prev_bq_shape_0 = cur_kv_head_bq.shape[0]
-          prev_kv_head_bkv = cur_kv_head__bkv
-          prev_kv_head_p = cur_kv_head_p
-          prev_kv_head_exp_m_diff = cur_kv_head_exp_m_diff
-          prev_kv_head_idx = cur_kv_head_idx
-
-        # Execute pv of last attention head.
-        assert prev_bq_shape_0 is not None
-        flash_attention_step2_pv(
-            prev_bq_shape_0,
-            prev_kv_head_bkv,
-            prev_kv_head_p,
-            prev_kv_head_exp_m_diff,
-            bkv_idx=bkv_idx,
-            kv_head_idx=prev_kv_head_idx,
-        )
-
-      lax.fori_loop(bkv_idx_start,
-                    num_bkv,
-                    compute_with_bkv,
-                    None,
-                    unroll=False)
+          bq = load_bq(bq_sem_idx,
+                       kv_head_idx,
+                       actual_bq_sz=actual_bq_sz)
+          bkv = bkv_lst[i]
+          flash_attention(
+              bq,
+              bkv,
+              bq_idx=bq_idx,
+              bkv_idx=bkv_idx,
+              kv_head_idx=kv_head_idx,
+          )
+
+    lax.fori_loop(0, num_bkv, compute_with_bkv, None, unroll=False)
 
     # Load acc and calculate final output.
     acc = acc_ref[...]
@@ -980,7 +838,7 @@ def _ragged_paged_attention_kernel(
   @pl.when(seq_idx == 0)
   def prologue():
     start_fetch_bq(0, 0, 0)
-    start_fetch_bkv(0,
+    start_fetch_bkv(0, 0, 0)
 
   @pl.when(seq_idx < decode_end)
   def process_decode():
@@ -1345,7 +1203,6 @@ def static_validate_inputs(
     static_argnames=(
         "sm_scale",
         "sliding_window",
-        "strict_sliding_window",
         "soft_cap",
         "mask_value",
         "q_scale",
@@ -1375,7 +1232,6 @@ def ragged_paged_attention_hd64(
     *,
     sm_scale: float = 1.0,
     sliding_window: int | None = None,
-    strict_sliding_window: bool = True,
     soft_cap: float | None = None,
     mask_value: float | None = DEFAULT_MASK_VALUE,
     q_scale: float | None = None,
@@ -1390,41 +1246,42 @@ def ragged_paged_attention_hd64(
     # Debug params.
     debug_mode: bool = False,
 ):
-  """A
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+  """A special Ragged paged attention version for head_dim=64 that supports mixed
+
+  prefill and decode.
+
+  Args:
+    queries: concatenated all sequences' queries.
+    keys: concatenated all sequences' keys (quantized).
+    values: concatenated all sequences' values (quantized).
+    kv_cache: paged KV cache with TPU-friendly shape.
+    kv_lens: padded kv lengths. Only the first num_seqs values are valid.
+    page_indices: flattened page indices look-up table by (seq_id, page_id).
+    cu_q_lens: the cumulative sum of the effective query lengths. Similar to
+      kv_lens, only the first num_seqs+1 values are valid.
+    distribution: (i, j, k) represents that sequences[0:i] are decode-only,
+      sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
+      k is also the total number of sequences.
+    attention_sink: optional attention sink for each q head.
+    actual_head_dim: the actual head size of the attention. Here we assume k and
+      v have the same actual head size.
+    sm_scale: the softmax scale which will be applied to the Q@K^T.
+    sliding_window: the sliding window size for the attention.
+    soft_cap: the logit soft cap for the attention.
+    mask_value: mask value for causal mask.
+    k_scale: the scale for the key cache.
+    v_scale: the scale for the value cache.
+    num_kv_pages_per_block: number of kv pages to be processed in one flash
+      attention block in the pallas kernel.
+    num_queries_per_block: number of kv pages to be processed in one flash
+      attention block in the pallas kernel.
+    vmem_limit_bytes: the vmem limit for the pallas kernel.
+    debug_mode: if true, RPA does not issue any DMAs or run flash attention but
+      print debug info. Need to compile with `--xla_tpu_enable_log_recorder`.
+
+  Returns:
+    The output of the attention.
+  """
   q, k, v = queries, keys, values
   static_validate_inputs(
       q,
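
As an illustration of the distribution tuple documented in the new docstring, here is a hypothetical helper (not part of the kernel) that partitions sequence indices the way the docstring describes:

```python
# Hypothetical helper: (i, j, k) splits sequence indices into decode-only,
# chunked-prefill-only, and mixed groups, with k equal to the total sequence count.
def split_by_distribution(distribution):
    i, j, k = distribution
    return list(range(0, i)), list(range(i, j)), list(range(j, k))

decode_only, prefill_only, mixed = split_by_distribution((2, 5, 7))
print(decode_only, prefill_only, mixed)  # [0, 1] [2, 3, 4] [5, 6]
```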
@@ -1494,7 +1351,7 @@ def ragged_paged_attention_hd64(
       pl.BlockSpec(memory_space=pltpu.HBM),
       pl.BlockSpec(memory_space=pltpu.HBM),
       None if attention_sink is None else pl.BlockSpec(
-          memory_space=pltpu.VMEM)
+          memory_space=pltpu.VMEM)
   ]
 
   out_specs = [
@@ -1558,7 +1415,6 @@ def ragged_paged_attention_hd64(
       _ragged_paged_attention_kernel,
       sm_scale=sm_scale,
       sliding_window=sliding_window,
-      strict_sliding_window=strict_sliding_window,
       soft_cap=soft_cap,
       mask_value=mask_value,
       q_scale=q_scale,
@@ -13,9 +13,9 @@ from tpu_inference import utils
 from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
     ragged_paged_attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.base import create_param
 from tpu_inference.layers.jax.rope_interface import apply_rope
+from tpu_inference.layers.jax.sharding import ShardingAxisName
 
 KVCache = Tuple[jax.Array, jax.Array]
 