tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511180814__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/kernels/fused_moe_v1_test.py +34 -303
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
- tests/lora/test_layers.py +6 -0
- tests/lora/utils.py +8 -0
- tests/test_envs.py +11 -32
- tests/test_utils.py +2 -1
- tpu_inference/__init__.py +3 -22
- tpu_inference/core/disagg_utils.py +8 -6
- tpu_inference/distributed/tpu_connector.py +4 -3
- tpu_inference/distributed/utils.py +2 -3
- tpu_inference/envs.py +8 -61
- tpu_inference/executors/ray_distributed_executor.py +2 -9
- tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +145 -266
- tpu_inference/layers/common/attention_interface.py +1 -7
- tpu_inference/layers/common/sharding.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +208 -170
- tpu_inference/layers/vllm/quantization/common.py +1 -6
- tpu_inference/layers/vllm/quantization/mxfp4.py +73 -138
- tpu_inference/layers/vllm/quantization/unquantized.py +64 -58
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +2 -1
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/common/model_loader.py +10 -43
- tpu_inference/models/jax/llama3.py +1 -2
- tpu_inference/models/jax/llama_eagle3.py +5 -8
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +1 -2
- tpu_inference/models/jax/qwen2_5_vl.py +48 -163
- tpu_inference/models/jax/qwen3.py +1 -2
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
- tpu_inference/models/jax/utils/weight_utils.py +143 -198
- tpu_inference/models/vllm/vllm_model_wrapper.py +8 -14
- tpu_inference/platforms/tpu_platform.py +31 -37
- tpu_inference/runner/compilation_manager.py +58 -141
- tpu_inference/runner/kv_cache.py +1 -1
- tpu_inference/runner/kv_cache_manager.py +18 -17
- tpu_inference/runner/persistent_batch_manager.py +2 -40
- tpu_inference/runner/structured_decoding_manager.py +3 -2
- tpu_inference/runner/tpu_runner.py +147 -271
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +21 -71
- tpu_inference/tpu_info.py +3 -4
- tpu_inference/utils.py +13 -36
- tpu_inference/worker/tpu_worker.py +25 -162
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/METADATA +3 -4
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/RECORD +55 -50
- tpu_inference/models/jax/llama_guard_4.py +0 -361
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/WHEEL +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/top_level.txt +0 -0
@@ -267,7 +267,6 @@ def _ragged_paged_attention_kernel(
     *,
     sm_scale: float,
     sliding_window: int | None = None,
-    strict_sliding_window: bool = True,
     soft_cap: float | None = None,
     mask_value: float = DEFAULT_MASK_VALUE,
     q_scale: float | None = None,
@@ -318,20 +317,19 @@ def _ragged_paged_attention_kernel(
     q_len = q_end - q_start
     kv_len = kv_lens_ref[seq_idx]

+    bkv_idx_start = 0 if sliding_window is None else jnp.maximum(
+        kv_len - sliding_window, 0) // bkv_sz
+
     if sliding_window is None:
-
+        next_bkv_idx_start = 0
     else:
-        bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window,
-                                    0) // bkv_sz

         def get_next_bkv_idx_start():
             next_kv_len = kv_lens_ref[seq_idx + 1]
-
-            return jnp.maximum(next_kv_len - next_q_len - sliding_window,
-                               0) // bkv_sz
+            return jnp.maximum(next_kv_len - sliding_window, 0) // bkv_sz

-
-
+        next_bkv_idx_start = lax.cond(seq_idx + 1 < num_seqs,
+                                      get_next_bkv_idx_start, lambda: 0)

     def debug_print(msg, *args):
         if debug_mode:
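The rewritten block above derives the first KV block to visit from `kv_len` alone rather than `kv_len - q_len`. A minimal standalone sketch of that index arithmetic, using illustrative values in place of the kernel's prefetched refs:

```python
import jax.numpy as jnp

# Illustrative values only; the real kernel reads these from its prefetch refs.
kv_len = jnp.int32(1000)   # total KV length of the current sequence
sliding_window = 128       # attention window size in tokens
bkv_sz = 64                # KV tokens processed per kernel block

# First KV block that can still overlap the sliding window.
bkv_idx_start = jnp.maximum(kv_len - sliding_window, 0) // bkv_sz
print(int(bkv_idx_start))  # 13 -> blocks 0..12 are skipped entirely
```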
@@ -352,7 +350,7 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] q_len={}", q_len)
     debug_print("[RPA debug] kv_len={}", kv_len)

-    def
+    def flash_attention(
         q,  # [actual_bq_sz * num_q_heads_per_kv_head, actual_head_dim_x2]
         kv,  # [bkv_sz, actual_head_dim_x2]
         *,
@@ -366,6 +364,7 @@ def _ragged_paged_attention_kernel(
         assert kv.shape == (bkv_sz, actual_head_dim_x2)
         head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
         head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
+        head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]

         def load_with_init(ref, init_val):
             return jnp.where(bkv_idx == bkv_idx_start,
@@ -387,19 +386,16 @@ def _ragged_paged_attention_kernel(
             s *= k_scale
         if q_scale is not None:
             s *= q_scale
-        if soft_cap is not None:
-            s = soft_cap * jnp.tanh(s / soft_cap)

         q_span = (kv_len - q_len + bq_idx * bq_sz +
                   lax.broadcasted_iota(jnp.int32, s.shape, 0) //
                   num_q_heads_per_kv_head)
         k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
-        mask =
+        mask = q_span < k_span

-        if
-
-
-        s = jnp.where(mask, s, mask_value)
+        if soft_cap is not None:
+            s = soft_cap * jnp.tanh(s / soft_cap)
+        s += jnp.where(mask, mask_value, 0.0)
         s_rowmax = jnp.max(s, axis=1, keepdims=True)

         if attention_sink_ref is not None:
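For reference, here is a self-contained sketch of the iota-based causal masking that the new lines apply additively to the scores. The toy shapes, the zero score matrix, and the `DEFAULT_MASK_VALUE` definition below are illustrative assumptions, not the kernel's own values:

```python
import jax.numpy as jnp
from jax import lax

DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.float32).max)  # assumed constant

# Toy block: 4 query rows attending over 6 KV tokens.
s = jnp.zeros((4, 6), dtype=jnp.float32)  # stand-in for the q @ k.T scores
kv_len, q_len, bq_idx, bq_sz = 6, 4, 0, 4
bkv_idx, bkv_sz = 0, 6
num_q_heads_per_kv_head = 1

q_span = (kv_len - q_len + bq_idx * bq_sz +
          lax.broadcasted_iota(jnp.int32, s.shape, 0) // num_q_heads_per_kv_head)
k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
mask = q_span < k_span                            # True where the key is in the future
s = s + jnp.where(mask, DEFAULT_MASK_VALUE, 0.0)  # additive masking, as in the hunk
```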
@@ -415,33 +411,15 @@ def _ragged_paged_attention_kernel(
         head_m_ref[...] = m_curr
         p = jnp.exp(s - broadcast_minor(m_curr, s.shape))

+        pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
+        if v_scale is not None:
+            pv *= v_scale
+
         p_rowsum = jnp.sum(p, axis=1, keepdims=True)
         exp_m_diff = jnp.exp(m_prev - m_curr)
         l_prev = load_with_init(head_l_ref, 1.0)
         l_curr = exp_m_diff * l_prev + p_rowsum
         head_l_ref[...] = l_curr
-
-        return p, exp_m_diff
-
-    def flash_attention_step2_pv(
-            q_shape_0,
-            kv,  # [bkv_sz, actual_head_dim_x2]
-            p,  # from step1
-            exp_m_diff,  # from step1
-            *,
-            bkv_idx,
-            kv_head_idx,
-    ):
-        head_acc_ref = acc_ref.at[kv_head_idx, :q_shape_0]
-
-        def load_with_init(ref, init_val):
-            return jnp.where(bkv_idx == bkv_idx_start,
-                             jnp.full_like(ref, init_val), ref[...])
-
-        pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
-        if v_scale is not None:
-            pv *= v_scale
-
         o_prev = load_with_init(head_acc_ref, 0.0)
         o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
         head_acc_ref[...] = o_curr
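The hunk above folds the former `flash_attention_step2_pv` into the single flash-attention body. As a rough reference, the per-block online-softmax update it performs looks approximately like this pure-`jnp` sketch (scaling, attention sinks, and VMEM refs omitted; the function name and signature are ours, not the kernel's):

```python
import jax.numpy as jnp

def online_softmax_step(s, kv, m_prev, l_prev, o_prev):
    """One KV-block update: s is [n, bkv] scores, kv is [bkv, d] values."""
    m_curr = jnp.maximum(m_prev, jnp.max(s, axis=1, keepdims=True))
    p = jnp.exp(s - m_curr)                # unnormalized block probabilities
    pv = jnp.einsum("nm,md->nd", p, kv)    # block contribution to the output
    exp_m_diff = jnp.exp(m_prev - m_curr)  # rescale factor for the old state
    l_curr = exp_m_diff * l_prev + jnp.sum(p, axis=1, keepdims=True)
    o_curr = exp_m_diff * o_prev + pv
    return m_curr, l_curr, o_curr
```

Once all blocks are accumulated, the normalized output is `o / l`.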
@@ -456,12 +434,7 @@ def _ragged_paged_attention_kernel(
     else:
         cp.start()

-    def _fetch_bkv(seq_idx,
-                   bkv_idx,
-                   bkv_sem_idx,
-                   *,
-                   is_full_fetch=False,
-                   wait=False):
+    def _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, wait=False):
         sem = sems.at[0, bkv_sem_idx]
         vmem_ref = bkv_x2_ref.at[bkv_sem_idx]

@@ -502,73 +475,42 @@ def _ragged_paged_attention_kernel(
         debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
         debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

-
-
-
-
-
-
-
-
-
-
-                wait=False,
-            )
-            debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-            return offset + sz
-
-        offset = lax.fori_loop(
-            0,
-            bkv_p_frm_cache,
-            loop_body,
-            0,  # offset
-            unroll=False,
+        # Fetch effective kv from kv cache.
+        def loop_body(i, offset):
+            sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
+            _async_copy(
+                cache_hbm_ref.at[pl.ds(
+                    page_indices_ref[page_indices_offset + i] * page_size,
+                    sz)],
+                vmem_ref.at[pl.ds(i * page_size, sz)],
+                sem,
+                wait,
             )
+            debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+            return offset + sz
+
+        offset = lax.fori_loop(
+            0,
+            bkv_p_frm_cache,
+            loop_body,
+            0,  # offset
+            unroll=False,
+        )

-
-
-
-
-
-
-        debug_print("[RPA debug] offset_in_bkv={}", offset)
-        _async_copy(
-            kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
-            vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
-            sem,
-            wait,
-        )
-
-        # NOTE(chengjiyao): This condition is true for the first two bkv fetches.
-        # We need to ensure the bkv_x2_ref VMEM buffer is fully initialized to
-        # avoid potential NaN values in regions not overwritten by actual data.
-        # This is done by padding the remaining parts of the buffer with data
-        # from the KV cache. This special handling is only strictly necessary
-        # until both buffers in the double buffer (bkv_x2_ref) have been written
-        # to at least once.
-        @pl.when(is_full_fetch)
-        def _make_sure_bkv_vmem_is_not_nan():
-            effective_sz = offset + bkv_sz_frm_new
-            remaining_sz = bkv_sz - effective_sz
-            _async_copy(
-                cache_hbm_ref.at[pl.ds(0, remaining_sz)],
-                vmem_ref.at[pl.ds(effective_sz, remaining_sz)],
-                sem,
-                wait,
-            )
-
-        return kv_len_start + offset, bkv_sz_frm_new
-    else:
-        offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
-        sz = lax.select(is_full_fetch, bkv_sz, offset + bkv_sz_frm_new)
-        dst = vmem_ref.at[pl.ds(0, sz)]
+        # Fetch kv directly from new kv.
+        @pl.when(bkv_sz_frm_new > 0)
+        def _fetch_bkv_from_new_kv():
+            new_kv_len_start = q_end - kv_left_frm_new
+            debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
+            debug_print("[RPA debug] offset_in_bkv={}", offset)
             _async_copy(
-
-
-                sem
-                wait
+                kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
+                vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
+                sem,
+                wait,
             )
-
+
+        return kv_len_start + offset, bkv_sz_frm_new

     def _update_kv_cache(seq_idx,
                          bkv_sem_idx,
@@ -604,41 +546,30 @@ def _ragged_paged_attention_kernel(
         debug_print("[RPA debug] p_ignore={}", p_ignore)
         debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

-
-
-
-
-            sz = jnp.minimum(page_size - ignore, update_sz)
-
-            _async_copy(
-                vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
-                                  sz)],
-                cache_hbm_ref.at[pl.ds(
-                    page_indices_ref[page_indices_offset + i] * page_size +
-                    ignore,
-                    sz,
-                )],
-                sem,
-                wait=False,
-            )
-            debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-            return update_sz - sz, 0
-
-        lax.fori_loop(
-            0,
-            kv_p_end - kv_p_start,
-            loop_body,
-            (update_sz, ignore),  # total transfer size
-            unroll=False,
-        )
-    else:
-        dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
+        def loop_body(i, states):
+            update_sz, ignore = states
+            sz = jnp.minimum(page_size - ignore, update_sz)
+
             _async_copy(
-
-
-
-
+                vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore, sz)],
+                cache_hbm_ref.at[pl.ds(
+                    page_indices_ref[page_indices_offset + i] * page_size +
+                    ignore,
+                    sz,
+                )],
+                sem,
+                wait,
             )
+            debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+            return update_sz - sz, 0
+
+        lax.fori_loop(
+            0,
+            kv_p_end - kv_p_start,
+            loop_body,
+            (update_sz, ignore),  # total transfer size
+            unroll=False,
+        )

     def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
         sem = sems.at[1, bq_sem_idx]
@@ -688,18 +619,11 @@ def _ragged_paged_attention_kernel(
             wait,
         )

-    def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx
-        return _fetch_bkv(seq_idx,
-                          bkv_idx,
-                          bkv_sem_idx,
-                          is_full_fetch=is_full_fetch)
+    def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
+        return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)

-    def wait_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx
-        return _fetch_bkv(seq_idx,
-                          bkv_idx,
-                          bkv_sem_idx,
-                          is_full_fetch=is_full_fetch,
-                          wait=True)
+    def wait_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
+        return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, wait=True)

     def start_fetch_bq(seq_idx, bq_idx, bq_sem_idx):
         return _fetch_bq(seq_idx, bq_idx, bq_sem_idx)
@@ -757,7 +681,7 @@ def _ragged_paged_attention_kernel(
         vec = ref[start::step]
         return vec

-    def strided_load_bkv(bkv_sem_idx, start, step):
+    def strided_load_bkv(bkv_sem_idx, start, step, *, bkv_mask):
         assert start % kv_packing == 0
         assert step % kv_packing == 0
         start //= kv_packing
@@ -766,6 +690,7 @@ def _ragged_paged_attention_kernel(
                                        bkv_sz * step, actual_head_dim_x2))

         kv = strided_load(kv_ref, start, step)
+        kv = lax.select(bkv_mask, kv, jnp.zeros_like(kv))
         bitwidth = 32 // kv_packing
         repack_ty = jnp.dtype(f"uint{bitwidth}")
         lst = []
@@ -812,17 +737,13 @@ def _ragged_paged_attention_kernel(
         next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
         next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)

-
-
-
-        next_bkv_start_idx = lax.select(
+        next_bkv_idx = lax.select(
+            is_last_bkv,
+            lax.select(
                 is_last_bq,
-
+                next_bkv_idx_start,
                 bkv_idx_start,
-        )
-        next_bkv_idx = lax.select(is_last_bkv, next_bkv_start_idx,
-                                  next_bkv_idx)
-
+            ), next_bkv_idx)
         return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx

     def compute_with_bq(bq_idx, _):
@@ -839,23 +760,22 @@ def _ragged_paged_attention_kernel(
     def compute_with_bkv(bkv_idx, _):
         # Create bitmask for KV.
         assert bkv_sz % kv_packing == 0
+        actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
+        bkv_shape = (bkv_sz, actual_head_dim_x2)
+        bkv_mask = lax.broadcasted_iota(jnp.int32, bkv_shape,
+                                        0) < actual_bkv_sz

         # Get next bkv ids.
         bkv_sem_idx = sem_ids_ref[1]
-        next_seq_idx,
-
+        next_seq_idx, _, next_bkv_idx, next_bkv_sem_idx = get_next_bkv_ids(
+            seq_idx, bq_idx, bkv_idx, bkv_sem_idx)

         # Prefetch next bkv
         @pl.when(next_seq_idx < num_seqs)
         def prefetch_next_bkv():
             sem_ids_ref[1] = next_bkv_sem_idx
-            start_fetch_bkv(
-
-                next_bkv_idx,
-                next_bkv_sem_idx,
-                is_full_fetch=next_seq_idx + next_bq_idx_for_kv +
-                next_bkv_idx < 2,
-            )
+            start_fetch_bkv(next_seq_idx, next_bkv_idx,
+                            next_bkv_sem_idx)

         # Wait for cur bq if not ready yet
         @pl.when(bkv_idx == bkv_idx_start)
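The new `bkv_mask` above zeroes out KV rows beyond the sequence's valid length before the block is repacked. A small hedged illustration of the same iota-based row mask, with made-up sizes:

```python
import jax.numpy as jnp
from jax import lax

# Illustrative block: 8 KV rows, of which only 5 are valid for this sequence.
bkv_sz, actual_head_dim_x2 = 8, 4
kv_len, bkv_idx = 5, 0
kv = jnp.ones((bkv_sz, actual_head_dim_x2), dtype=jnp.float32)

actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
bkv_mask = lax.broadcasted_iota(jnp.int32, kv.shape, 0) < actual_bkv_sz
kv = lax.select(bkv_mask, kv, jnp.zeros_like(kv))  # rows 5..7 become zeros
```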
@@ -863,12 +783,8 @@ def _ragged_paged_attention_kernel(
             wait_fetch_bq(seq_idx, bq_idx, bq_sem_idx)

         # Wait for cur bkv
-        offset, update_sz = wait_fetch_bkv(
-
-            bkv_idx,
-            bkv_sem_idx,
-            is_full_fetch=seq_idx + bq_idx + bkv_idx < 2,
-        )
+        offset, update_sz = wait_fetch_bkv(seq_idx, bkv_idx,
+                                           bkv_sem_idx)

         # Start updating bkv to kv cache if applicable.
         # Only needed in first bq loop.
@@ -887,64 +803,29 @@ def _ragged_paged_attention_kernel(
                 return

         # Flash attention with cur bkv and bq
-        prev_bq_shape_0 = None
-        prev_kv_head_bkv = None
-        prev_kv_head_idx = None
-        prev_kv_head_p = None
-        prev_kv_head_exp_m_diff = None
         for kv_head_start in range(0, actual_num_kv_heads, kv_packing):
             bkv_lst = strided_load_bkv(
                 bkv_sem_idx,
                 kv_head_start,
                 num_kv_heads,
+                bkv_mask=bkv_mask,
             )
             assert len(bkv_lst) == kv_packing
             for i in range(kv_packing):
-
-                if
+                kv_head_idx = kv_head_start + i
+                if kv_head_idx >= actual_num_kv_heads:
                     break
-
-
-
-
-
-
-
-
-
-
-
-                    cur_kv_head_bq,
-                    cur_kv_head__bkv,
-                    bq_idx=bq_idx,
-                    bkv_idx=bkv_idx,
-                    kv_head_idx=cur_kv_head_idx,
-                ))
-                if prev_bq_shape_0 is not None:
-                    flash_attention_step2_pv(
-                        prev_bq_shape_0,
-                        prev_kv_head_bkv,
-                        prev_kv_head_p,
-                        prev_kv_head_exp_m_diff,
-                        bkv_idx=bkv_idx,
-                        kv_head_idx=prev_kv_head_idx,
-                    )
-                prev_bq_shape_0 = cur_kv_head_bq.shape[0]
-                prev_kv_head_bkv = cur_kv_head__bkv
-                prev_kv_head_p = cur_kv_head_p
-                prev_kv_head_exp_m_diff = cur_kv_head_exp_m_diff
-                prev_kv_head_idx = cur_kv_head_idx
-
-            # Execute pv of last attention head.
-            assert prev_bq_shape_0 is not None
-            flash_attention_step2_pv(
-                prev_bq_shape_0,
-                prev_kv_head_bkv,
-                prev_kv_head_p,
-                prev_kv_head_exp_m_diff,
-                bkv_idx=bkv_idx,
-                kv_head_idx=prev_kv_head_idx,
-            )
+                bq = load_bq(bq_sem_idx,
+                             kv_head_idx,
+                             actual_bq_sz=actual_bq_sz)
+                bkv = bkv_lst[i]
+                flash_attention(
+                    bq,
+                    bkv,
+                    bq_idx=bq_idx,
+                    bkv_idx=bkv_idx,
+                    kv_head_idx=kv_head_idx,
+                )

         lax.fori_loop(bkv_idx_start,
                       num_bkv,
@@ -980,7 +861,7 @@ def _ragged_paged_attention_kernel(
     @pl.when(seq_idx == 0)
     def prologue():
         start_fetch_bq(0, 0, 0)
-        start_fetch_bkv(0, bkv_idx_start, 0
+        start_fetch_bkv(0, bkv_idx_start, 0)

     @pl.when(seq_idx < decode_end)
     def process_decode():
@@ -1345,7 +1226,6 @@ def static_validate_inputs(
     static_argnames=(
         "sm_scale",
         "sliding_window",
-        "strict_sliding_window",
        "soft_cap",
        "mask_value",
        "q_scale",
@@ -1375,7 +1255,6 @@ def ragged_paged_attention_hd64(
     *,
     sm_scale: float = 1.0,
     sliding_window: int | None = None,
-    strict_sliding_window: bool = True,
     soft_cap: float | None = None,
     mask_value: float | None = DEFAULT_MASK_VALUE,
     q_scale: float | None = None,
@@ -1390,41 +1269,42 @@ def ragged_paged_attention_hd64(
     # Debug params.
     debug_mode: bool = False,
 ):
-    """A
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    """A special Ragged paged attention version for head_dim=64 that supports mixed
+
+    prefill and decode.
+
+    Args:
+      queries: concatenated all sequences' queries.
+      keys: concatenated all sequences' keys (quantized).
+      values: concatenated all sequences' values (quantized).
+      kv_cache: paged KV cache with TPU-friendly shape.
+      kv_lens: padded kv lengths. Only the first num_seqs values are valid.
+      page_indices: flattened page indices look-up table by (seq_id, page_id).
+      cu_q_lens: the cumulative sum of the effective query lengths. Similar to
+        kv_lens, only the first num_seqs+1 values are valid.
+      distribution: (i, j, k) represents that sequences[0:i] are decode-only,
+        sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
+        k is also the total number of sequences.
+      attention_sink: optional attention sink for each q head.
+      actual_head_dim: the actual head size of the attention. Here we assume k and
+        v have the same actual head size.
+      sm_scale: the softmax scale which will be applied to the Q@K^T.
+      sliding_window: the sliding window size for the attention.
+      soft_cap: the logit soft cap for the attention.
+      mask_value: mask value for causal mask.
+      k_scale: the scale for the key cache.
+      v_scale: the scale for the value cache.
+      num_kv_pages_per_block: number of kv pages to be processed in one flash
+        attention block in the pallas kernel.
+      num_queries_per_block: number of kv pages to be processed in one flash
+        attention block in the pallas kernel.
+      vmem_limit_bytes: the vmem limit for the pallas kernel.
+      debug_mode: if true, RPA does not issue any DMAs or run flash attention but
+        print debug info. Need to compile with `--xla_tpu_enable_log_recorder`.
+
+    Returns:
+      The output of the attention.
+    """
     q, k, v = queries, keys, values
     static_validate_inputs(
         q,
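The `distribution` argument documented above is just three cumulative boundaries over the batch. A hedged illustration of how a caller might interpret it (variable names here are ours, not part of the kernel API):

```python
# distribution = (i, j, k): sequences[0:i] are decode-only,
# sequences[i:j] are chunked-prefill-only, sequences[j:k] are mixed,
# and k is also the total number of sequences.
distribution = (3, 5, 8)        # illustrative values
i, j, k = distribution

seq_ids = list(range(k))
decode_seqs = seq_ids[:i]       # [0, 1, 2]
prefill_seqs = seq_ids[i:j]     # [3, 4]
mixed_seqs = seq_ids[j:k]       # [5, 6, 7]
num_seqs = k
```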
@@ -1494,7 +1374,7 @@ def ragged_paged_attention_hd64(
         pl.BlockSpec(memory_space=pltpu.HBM),
         pl.BlockSpec(memory_space=pltpu.HBM),
         None if attention_sink is None else pl.BlockSpec(
-            memory_space=pltpu.VMEM)
+            memory_space=pltpu.VMEM)
     ]

     out_specs = [
@@ -1558,7 +1438,6 @@ def ragged_paged_attention_hd64(
         _ragged_paged_attention_kernel,
         sm_scale=sm_scale,
         sliding_window=sliding_window,
-        strict_sliding_window=strict_sliding_window,
         soft_cap=soft_cap,
         mask_value=mask_value,
         q_scale=q_scale,
@@ -308,13 +308,7 @@ def sharded_ragged_paged_attention(
     args = (q, k, v, kv_cache, kv_lens, page_indices, cu_q_lens, distribution)

     use_hd64 = q.shape[-1] == 64
-
-    func = ragged_paged_attention
-    if use_hd64:
-        func = functools.partial(ragged_paged_attention_hd64,
-                                 strict_sliding_window=False)
-    else:
-        func = ragged_paged_attention
+    func = ragged_paged_attention_hd64 if use_hd64 else ragged_paged_attention

     if attention_sink is not None:
         if not use_hd64:
@@ -1,5 +1,6 @@
 import json
 import math
+import os
 from dataclasses import asdict, dataclass
 from typing import TYPE_CHECKING, List, Optional

@@ -7,7 +8,7 @@ import jax.numpy as jnp
 import numpy as np
 from jax.sharding import Mesh

-from tpu_inference import
+from tpu_inference import utils

 if TYPE_CHECKING:
     from vllm.v1.configs.vllm_config import VllmConfig
@@ -47,7 +48,7 @@ class ShardingAxisName2D:


 try:
-    _use_base_sharding =
+    _use_base_sharding = os.getenv("NEW_MODEL_DESIGN", False)
     if _use_base_sharding:
         ShardingAxisName = ShardingAxisNameBase
     else:
@@ -165,10 +166,9 @@ class ShardingConfigManager:
                 f"LoRA is not supported with data parallelism "
                 f"(DP size: {total_dp_size}). Please disable LoRA or "
                 f"set data parallelism to 1.")
-
-        if not envs.NEW_MODEL_DESIGN:
+        if not os.environ.get("NEW_MODEL_DESIGN", False):
             raise ValueError(
-                "Must run
+                "Must run DP with NEW_MODEL_DESIGN enabled. Please set the "
                 "NEW_MODEL_DESIGN=True.")

     @property