tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202512030818__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/lora/test_layers.py +0 -6
- tests/lora/utils.py +0 -8
- tests/test_envs.py +32 -11
- tests/test_utils.py +1 -2
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +3 -4
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +61 -8
- tpu_inference/executors/ray_distributed_executor.py +31 -11
- tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +213 -126
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +74 -25
- tpu_inference/layers/vllm/quantization/common.py +6 -1
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -62
- tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +45 -11
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +3 -6
- tpu_inference/models/jax/utils/weight_utils.py +198 -143
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -7
- tpu_inference/platforms/tpu_platform.py +28 -22
- tpu_inference/runner/compilation_manager.py +144 -59
- tpu_inference/runner/kv_cache_manager.py +17 -18
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +271 -147
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -21
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +36 -13
- tpu_inference/worker/tpu_worker.py +162 -25
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/METADATA +3 -2
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/RECORD +48 -53
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/top_level.txt +0 -0
@@ -440,42 +440,54 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
     debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
 
-
-
-
-
-
-
-
-
-
-
+    if not wait:
+      # Fetch effective kv from kv cache.
+      def loop_body(i, offset):
+        sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
+        _async_copy(
+            cache_hbm_ref.at[pl.ds(
+                page_indices_ref[page_indices_offset + i] * page_size,
+                sz)],
+            vmem_ref.at[pl.ds(i * page_size, sz)],
+            sem,
+            wait=False,
+        )
+        debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+        return offset + sz
+
+      offset = lax.fori_loop(
+          0,
+          bkv_p_frm_cache,
+          loop_body,
+          0,  # offset
+          unroll=False,
       )
-        debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-        return offset + sz
-
-      offset = lax.fori_loop(
-          0,
-          bkv_p_frm_cache,
-          loop_body,
-          0,  # offset
-          unroll=False,
-      )
 
-
-
-
-
-
-
+      # Fetch kv directly from new kv.
+      @pl.when(bkv_sz_frm_new > 0)
+      def _fetch_bkv_from_new_kv():
+        new_kv_len_start = q_end - kv_left_frm_new
+        debug_print("[RPA debug] new_kv_len_start={}",
+                    new_kv_len_start)
+        debug_print("[RPA debug] offset_in_bkv={}", offset)
+        _async_copy(
+            kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
+            vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
+            sem,
+            wait,
+        )
+
+      return kv_len_start + offset, bkv_sz_frm_new
+    else:
+      offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
+      dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
       _async_copy(
-
-
-          sem,
-          wait,
+          src=dst,
+          dst=dst,
+          sem=sem,
+          wait=True,
       )
-
-      return kv_len_start + offset, bkv_sz_frm_new
+      return kv_len_start + offset, bkv_sz_frm_new
 
   def _update_kv_cache(seq_idx,
                        bkv_sem_idx,
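The rewritten fetch path above separates issuing the per-page copies (`wait=False`, one DMA per KV page inside `loop_body`) from waiting on them (`wait=True`, a single call covering the whole destination slice). The `_async_copy` helper itself is not part of this diff; the sketch below is one plausible shape for it, assuming it is a thin wrapper over Pallas TPU's `make_async_copy`, and is an illustration rather than the package's actual code.

# Hypothetical sketch of the `_async_copy(src, dst, sem, wait)` helper used above,
# assuming it wraps jax.experimental.pallas.tpu.make_async_copy.
from jax.experimental.pallas import tpu as pltpu

def _async_copy(src, dst, sem, wait):
  copy = pltpu.make_async_copy(src, dst, sem)  # DMA descriptor guarded by `sem`
  if wait:
    copy.wait()   # block until the DMA tracked by `sem` has completed
  else:
    copy.start()  # issue the DMA and return immediately

Under that reading, only the semaphore matters on the wait side, which would explain why the new code passes the same `dst` slice as both `src=` and `dst=` when it waits with `wait=True`.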
@@ -511,30 +523,41 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] p_ignore={}", p_ignore)
     debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
 
-
-
-
-
+    if not wait:
+
+      def loop_body(i, states):
+        update_sz, ignore = states
+        sz = jnp.minimum(page_size - ignore, update_sz)
+
+        _async_copy(
+            vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
+                              sz)],
+            cache_hbm_ref.at[pl.ds(
+                page_indices_ref[page_indices_offset + i] * page_size +
+                ignore,
+                sz,
+            )],
+            sem,
+            wait=False,
+        )
+        debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+        return update_sz - sz, 0
+
+      lax.fori_loop(
+          0,
+          kv_p_end - kv_p_start,
+          loop_body,
+          (update_sz, ignore),  # total transfer size
+          unroll=False,
+      )
+    else:
+      dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
       _async_copy(
-
-
-
-
-          sz,
-          )],
-          sem,
-          wait,
+          src=dst,
+          dst=dst,
+          sem=sem,
+          wait=True,
       )
-        debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-        return update_sz - sz, 0
-
-      lax.fori_loop(
-          0,
-          kv_p_end - kv_p_start,
-          loop_body,
-          (update_sz, ignore),  # total transfer size
-          unroll=False,
-      )
 
   def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
     sem = sems.at[1, bq_sem_idx]
@@ -267,6 +267,7 @@ def _ragged_paged_attention_kernel(
     *,
     sm_scale: float,
     sliding_window: int | None = None,
+    strict_sliding_window: bool = True,
     soft_cap: float | None = None,
     mask_value: float = DEFAULT_MASK_VALUE,
     q_scale: float | None = None,
@@ -317,19 +318,20 @@ def _ragged_paged_attention_kernel(
   q_len = q_end - q_start
   kv_len = kv_lens_ref[seq_idx]
 
-  bkv_idx_start = 0 if sliding_window is None else jnp.maximum(
-      kv_len - sliding_window, 0) // bkv_sz
-
   if sliding_window is None:
-
+    bkv_idx_start = next_seq_bkv_idx_start = 0
   else:
+    bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window,
+                                0) // bkv_sz
 
     def get_next_bkv_idx_start():
       next_kv_len = kv_lens_ref[seq_idx + 1]
-
+      next_q_len = cu_q_lens_ref[seq_idx + 2] - q_end
+      return jnp.maximum(next_kv_len - next_q_len - sliding_window,
+                         0) // bkv_sz
 
-
-
+    next_seq_bkv_idx_start = lax.cond(seq_idx + 1 < num_seqs,
+                                      get_next_bkv_idx_start, lambda: 0)
 
   def debug_print(msg, *args):
     if debug_mode:
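With a sliding window, KV blocks that lie entirely before the window of the oldest query token never contribute to the output, so the block loop can start later; the new formula subtracts `q_len` so the window is anchored at the first query token of the sequence rather than at `kv_len`. A small numeric illustration with made-up sizes (not taken from the package):

# Made-up sizes: where KV-block iteration can start once a sliding window is applied.
kv_len, q_len, sliding_window, bkv_sz = 1000, 16, 128, 256

# The oldest query token sits at position kv_len - q_len and only needs keys within
# `sliding_window` of that position, so earlier KV blocks can be skipped entirely.
bkv_idx_start = max(kv_len - q_len - sliding_window, 0) // bkv_sz
print(bkv_idx_start)  # 3 -> KV blocks 0, 1 and 2 are never processed for this sequence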
@@ -350,7 +352,7 @@ def _ragged_paged_attention_kernel(
   debug_print("[RPA debug] q_len={}", q_len)
   debug_print("[RPA debug] kv_len={}", kv_len)
 
-  def
+  def flash_attention_step1_qk_softmax(
       q,  # [actual_bq_sz * num_q_heads_per_kv_head, actual_head_dim_x2]
       kv,  # [bkv_sz, actual_head_dim_x2]
       *,
@@ -364,7 +366,6 @@ def _ragged_paged_attention_kernel(
     assert kv.shape == (bkv_sz, actual_head_dim_x2)
     head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
     head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
-    head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]
 
     def load_with_init(ref, init_val):
       return jnp.where(bkv_idx == bkv_idx_start,
@@ -386,16 +387,19 @@ def _ragged_paged_attention_kernel(
       s *= k_scale
     if q_scale is not None:
       s *= q_scale
+    if soft_cap is not None:
+      s = soft_cap * jnp.tanh(s / soft_cap)
 
     q_span = (kv_len - q_len + bq_idx * bq_sz +
               lax.broadcasted_iota(jnp.int32, s.shape, 0) //
               num_q_heads_per_kv_head)
     k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
-    mask =
+    mask = k_span <= q_span
 
-    if
-
-
+    if sliding_window is not None and strict_sliding_window:
+      mask = jnp.logical_and(mask, q_span - sliding_window < k_span)
+
+    s = jnp.where(mask, s, mask_value)
     s_rowmax = jnp.max(s, axis=1, keepdims=True)
 
     if attention_sink_ref is not None:
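The mask above combines plain causal masking with an optional strict sliding window. The toy snippet below replays the same index arithmetic in ordinary JAX outside the kernel, with made-up sizes, to show which (query, key) pairs survive each variant:

# Toy replay of the mask logic (plain JAX, made-up sizes; not the kernel itself).
import jax.numpy as jnp
from jax import lax

kv_len, q_len, sliding_window = 8, 4, 3
num_q_heads_per_kv_head, bq_idx, bq_sz, bkv_idx, bkv_sz = 1, 0, 4, 0, 8

shape = (q_len * num_q_heads_per_kv_head, bkv_sz)
q_span = (kv_len - q_len + bq_idx * bq_sz +
          lax.broadcasted_iota(jnp.int32, shape, 0) // num_q_heads_per_kv_head)
k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, shape, 1)

causal = k_span <= q_span
strict = jnp.logical_and(causal, q_span - sliding_window < k_span)
print(causal.astype(jnp.int32))  # causal band shifted right by kv_len - q_len
print(strict.astype(jnp.int32))  # same, but each row keeps only its last 3 keys

When `strict_sliding_window` is False the second constraint is skipped; the kernel then appears to rely on the sliding-window-aware block start computed earlier, which is cheaper but can let a few keys just outside the window contribute.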
@@ -411,15 +415,33 @@ def _ragged_paged_attention_kernel(
     head_m_ref[...] = m_curr
     p = jnp.exp(s - broadcast_minor(m_curr, s.shape))
 
-    pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
-    if v_scale is not None:
-      pv *= v_scale
-
     p_rowsum = jnp.sum(p, axis=1, keepdims=True)
     exp_m_diff = jnp.exp(m_prev - m_curr)
     l_prev = load_with_init(head_l_ref, 1.0)
     l_curr = exp_m_diff * l_prev + p_rowsum
     head_l_ref[...] = l_curr
+
+    return p, exp_m_diff
+
+  def flash_attention_step2_pv(
+      q_shape_0,
+      kv,  # [bkv_sz, actual_head_dim_x2]
+      p,  # from step1
+      exp_m_diff,  # from step1
+      *,
+      bkv_idx,
+      kv_head_idx,
+  ):
+    head_acc_ref = acc_ref.at[kv_head_idx, :q_shape_0]
+
+    def load_with_init(ref, init_val):
+      return jnp.where(bkv_idx == bkv_idx_start,
+                       jnp.full_like(ref, init_val), ref[...])
+
+    pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
+    if v_scale is not None:
+      pv *= v_scale
+
     o_prev = load_with_init(head_acc_ref, 0.0)
     o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
     head_acc_ref[...] = o_curr
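Splitting the per-block update into a QK/softmax step and a PV step does not change the math: together the two helpers still perform the standard online-softmax (FlashAttention-style) accumulator update. A compact reference for that update in plain jax.numpy, ignoring the quantization scales, attention sinks, and masking the kernel also handles:

# Reference online-softmax block update (what step1 + step2 compute together).
import jax.numpy as jnp

def online_softmax_block(s, m_prev, l_prev, acc_prev, v):
  # step1: running row-max, exponentiated scores, and running normalizer.
  m_curr = jnp.maximum(m_prev, jnp.max(s, axis=1, keepdims=True))
  p = jnp.exp(s - m_curr)
  exp_m_diff = jnp.exp(m_prev - m_curr)
  l_curr = exp_m_diff * l_prev + jnp.sum(p, axis=1, keepdims=True)
  # step2: weighted values plus rescaling of the previous accumulator.
  acc_curr = exp_m_diff * acc_prev + p @ v
  return m_curr, l_curr, acc_curr  # the final output is acc / l after the last block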
@@ -475,42 +497,54 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
     debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
 
-
-
-
-
-
-
-
-
-
-
+    if not wait:
+      # Fetch effective kv from kv cache.
+      def loop_body(i, offset):
+        sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
+        _async_copy(
+            cache_hbm_ref.at[pl.ds(
+                page_indices_ref[page_indices_offset + i] * page_size,
+                sz)],
+            vmem_ref.at[pl.ds(i * page_size, sz)],
+            sem,
+            wait=False,
+        )
+        debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+        return offset + sz
+
+      offset = lax.fori_loop(
+          0,
+          bkv_p_frm_cache,
+          loop_body,
+          0,  # offset
+          unroll=False,
      )
-        debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-        return offset + sz
-
-      offset = lax.fori_loop(
-          0,
-          bkv_p_frm_cache,
-          loop_body,
-          0,  # offset
-          unroll=False,
-      )
 
-
-
-
-
-
-
+      # Fetch kv directly from new kv.
+      @pl.when(bkv_sz_frm_new > 0)
+      def _fetch_bkv_from_new_kv():
+        new_kv_len_start = q_end - kv_left_frm_new
+        debug_print("[RPA debug] new_kv_len_start={}",
+                    new_kv_len_start)
+        debug_print("[RPA debug] offset_in_bkv={}", offset)
+        _async_copy(
+            kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
+            vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
+            sem,
+            wait,
+        )
+
+      return kv_len_start + offset, bkv_sz_frm_new
+    else:
+      offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
+      dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
       _async_copy(
-
-
-          sem,
-          wait,
+          src=dst,
+          dst=dst,
+          sem=sem,
+          wait=True,
      )
-
-      return kv_len_start + offset, bkv_sz_frm_new
+      return kv_len_start + offset, bkv_sz_frm_new
 
   def _update_kv_cache(seq_idx,
                        bkv_sem_idx,
@@ -546,30 +580,41 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] p_ignore={}", p_ignore)
     debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
 
-
-
-
-
+    if not wait:
+
+      def loop_body(i, states):
+        update_sz, ignore = states
+        sz = jnp.minimum(page_size - ignore, update_sz)
+
+        _async_copy(
+            vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
+                              sz)],
+            cache_hbm_ref.at[pl.ds(
+                page_indices_ref[page_indices_offset + i] * page_size +
+                ignore,
+                sz,
+            )],
+            sem,
+            wait=False,
+        )
+        debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+        return update_sz - sz, 0
+
+      lax.fori_loop(
+          0,
+          kv_p_end - kv_p_start,
+          loop_body,
+          (update_sz, ignore),  # total transfer size
+          unroll=False,
+      )
+    else:
+      dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
       _async_copy(
-
-
-
-
-          sz,
-          )],
-          sem,
-          wait,
+          src=dst,
+          dst=dst,
+          sem=sem,
+          wait=True,
      )
-        debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-        return update_sz - sz, 0
-
-      lax.fori_loop(
-          0,
-          kv_p_end - kv_p_start,
-          loop_body,
-          (update_sz, ignore),  # total transfer size
-          unroll=False,
-      )
 
   def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
     sem = sems.at[1, bq_sem_idx]
@@ -737,13 +782,17 @@ def _ragged_paged_attention_kernel(
     next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
     next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)
 
-
-
-
+    if sliding_window is None:
+      next_bkv_start_idx = 0
+    else:
+      next_bkv_start_idx = lax.select(
          is_last_bq,
-
+          next_seq_bkv_idx_start,
          bkv_idx_start,
-      )
+      )
+    next_bkv_idx = lax.select(is_last_bkv, next_bkv_start_idx,
+                              next_bkv_idx)
+
     return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx
 
   def compute_with_bq(bq_idx, _):
@@ -803,6 +852,11 @@ def _ragged_paged_attention_kernel(
         return
 
     # Flash attention with cur bkv and bq
+    prev_bq_shape_0 = None
+    prev_kv_head_bkv = None
+    prev_kv_head_idx = None
+    prev_kv_head_p = None
+    prev_kv_head_exp_m_diff = None
    for kv_head_start in range(0, actual_num_kv_heads, kv_packing):
      bkv_lst = strided_load_bkv(
          bkv_sem_idx,
@@ -812,20 +866,51 @@ def _ragged_paged_attention_kernel(
       )
       assert len(bkv_lst) == kv_packing
       for i in range(kv_packing):
-
-        if
+        cur_kv_head_idx = kv_head_start + i
+        if cur_kv_head_idx >= actual_num_kv_heads:
          break
-
-
-
-
-
-
-
-
-
-
-
+        cur_kv_head_bq = load_bq(bq_sem_idx,
+                                 cur_kv_head_idx,
+                                 actual_bq_sz=actual_bq_sz)
+        cur_kv_head__bkv = bkv_lst[i]
+        # FlashAttention is divided into `flash_attention_step1_qk_softmax`
+        # and `flash_attention_step2_pv` to pipeline the computation.
+        # `step2_pv` for the previous KV head, which depends on the softmax
+        # output, is overlapped with `step1_qk_softmax` for the current KV
+        # head, reducing overall wait times.
+        cur_kv_head_p, cur_kv_head_exp_m_diff = (
+            flash_attention_step1_qk_softmax(
+                cur_kv_head_bq,
+                cur_kv_head__bkv,
+                bq_idx=bq_idx,
+                bkv_idx=bkv_idx,
+                kv_head_idx=cur_kv_head_idx,
+            ))
+        if prev_bq_shape_0 is not None:
+          flash_attention_step2_pv(
+              prev_bq_shape_0,
+              prev_kv_head_bkv,
+              prev_kv_head_p,
+              prev_kv_head_exp_m_diff,
+              bkv_idx=bkv_idx,
+              kv_head_idx=prev_kv_head_idx,
+          )
+        prev_bq_shape_0 = cur_kv_head_bq.shape[0]
+        prev_kv_head_bkv = cur_kv_head__bkv
+        prev_kv_head_p = cur_kv_head_p
+        prev_kv_head_exp_m_diff = cur_kv_head_exp_m_diff
+        prev_kv_head_idx = cur_kv_head_idx
+
+    # Execute pv of last attention head.
+    assert prev_bq_shape_0 is not None
+    flash_attention_step2_pv(
+        prev_bq_shape_0,
+        prev_kv_head_bkv,
+        prev_kv_head_p,
+        prev_kv_head_exp_m_diff,
+        bkv_idx=bkv_idx,
+        kv_head_idx=prev_kv_head_idx,
+    )
 
    lax.fori_loop(bkv_idx_start,
                  num_bkv,
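The restructured head loop above is a one-stage software pipeline: step 1 for the current KV head is issued before step 2 of the previous head is applied, and the last head is drained after the loop. Schematically (illustrative Python only; `stage1`/`stage2` stand in for the two kernel helpers):

# Generic shape of the pipelining used above; not the kernel code itself.
def pipelined(items, stage1, stage2):
  prev = None
  for item in items:
    cur = stage1(item)    # e.g. flash_attention_step1_qk_softmax for this head
    if prev is not None:
      stage2(*prev)       # e.g. flash_attention_step2_pv for the previous head
    prev = (item, cur)
  if prev is not None:
    stage2(*prev)         # drain: finish the last head after the loop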
@@ -1226,6 +1311,7 @@ def static_validate_inputs(
    static_argnames=(
        "sm_scale",
        "sliding_window",
+        "strict_sliding_window",
        "soft_cap",
        "mask_value",
        "q_scale",
@@ -1255,6 +1341,7 @@ def ragged_paged_attention_hd64(
    *,
    sm_scale: float = 1.0,
    sliding_window: int | None = None,
+    strict_sliding_window: bool = True,
    soft_cap: float | None = None,
    mask_value: float | None = DEFAULT_MASK_VALUE,
    q_scale: float | None = None,
@@ -1269,42 +1356,41 @@ def ragged_paged_attention_hd64(
    # Debug params.
    debug_mode: bool = False,
 ):
-  """A
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-  """
+  """A variant of ragged paged attention for head_dim=64.
+
+  Args:
+    queries: concatenated all sequences' queries.
+    keys: concatenated all sequences' keys (quantized).
+    values: concatenated all sequences' values (quantized).
+    kv_cache: paged KV cache with TPU-friendly shape.
+    kv_lens: padded kv lengths. Only the first num_seqs values are valid.
+    page_indices: flattened page indices look-up table by (seq_id, page_id).
+    cu_q_lens: the cumulative sum of the effective query lengths. Similar to
+      kv_lens, only the first num_seqs+1 values are valid.
+    distribution: (i, j, k) represents that sequences[0:i] are decode-only,
+      sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
+      k is also the total number of sequences.
+    attention_sink: optional attention sink for each q head.
+    sm_scale: the softmax scale which will be applied to the Q@K^T.
+    sliding_window: the sliding window size for the attention.
+    strict_sliding_window: compute tokens that are strictly within the window.
+    soft_cap: the logit soft cap for the attention.
+    mask_value: mask value for causal mask.
+    q_scale: the scale for the query.
+    k_scale: the scale for the key cache.
+    v_scale: the scale for the value cache.
+    chunk_prefill_size: the chunk prefill size for the attention.
+    num_kv_pages_per_block: number of kv pages to be processed in one flash
+      attention block in the pallas kernel.
+    num_queries_per_block: number of kv pages to be processed in one flash
+      attention block in the pallas kernel.
+    vmem_limit_bytes: the vmem limit for the pallas kernel.
+    debug_mode: if true, RPA does not issue any DMAs or run flash attention but
+      print debug info. Need to compile with `--xla_tpu_enable_log_recorder`.
+
+  Returns:
+    The output of the attention.
+  """
   q, k, v = queries, keys, values
   static_validate_inputs(
       q,
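The `distribution` argument described in the new docstring packs three prefix boundaries into one tuple; with made-up counts:

# Made-up example of the (i, j, k) `distribution` tuple from the docstring:
# 3 decode-only sequences, then 2 chunked-prefill-only, then 1 mixed sequence.
num_decode, num_chunked_prefill, num_mixed = 3, 2, 1
distribution = (
    num_decode,                                    # i = 3
    num_decode + num_chunked_prefill,              # j = 5
    num_decode + num_chunked_prefill + num_mixed,  # k = 6 == total number of sequences
)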
@@ -1374,7 +1460,7 @@ def ragged_paged_attention_hd64(
      pl.BlockSpec(memory_space=pltpu.HBM),
      pl.BlockSpec(memory_space=pltpu.HBM),
      None if attention_sink is None else pl.BlockSpec(
-          memory_space=pltpu.VMEM)
+          memory_space=pltpu.VMEM),
  ]
 
  out_specs = [
@@ -1438,6 +1524,7 @@ def ragged_paged_attention_hd64(
          _ragged_paged_attention_kernel,
          sm_scale=sm_scale,
          sliding_window=sliding_window,
+          strict_sliding_window=strict_sliding_window,
          soft_cap=soft_cap,
          mask_value=mask_value,
          q_scale=q_scale,
@@ -308,7 +308,13 @@ def sharded_ragged_paged_attention(
  args = (q, k, v, kv_cache, kv_lens, page_indices, cu_q_lens, distribution)
 
  use_hd64 = q.shape[-1] == 64
-
+
+  func = ragged_paged_attention
+  if use_hd64:
+    func = functools.partial(ragged_paged_attention_hd64,
+                             strict_sliding_window=False)
+  else:
+    func = ragged_paged_attention
 
  if attention_sink is not None:
    if not use_hd64: