tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/kernels/mla_v1_test.py +129 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
- tests/lora/test_layers.py +4 -7
- tests/lora/test_lora_perf.py +53 -0
- tests/lora/utils.py +0 -8
- tests/test_envs.py +110 -12
- tests/test_quantization.py +3 -0
- tests/test_utils.py +1 -2
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +3 -4
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +93 -9
- tpu_inference/executors/ray_distributed_executor.py +9 -2
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
- tpu_inference/kernels/mla/v1/kernel.py +98 -120
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +11 -7
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +170 -208
- tpu_inference/layers/vllm/linear_common.py +43 -21
- tpu_inference/layers/vllm/quantization/common.py +11 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
- tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
- tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +84 -28
- tpu_inference/models/jax/deepseek_v3.py +185 -64
- tpu_inference/models/jax/gpt_oss.py +3 -3
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
- tpu_inference/models/jax/utils/weight_utils.py +205 -144
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
- tpu_inference/platforms/tpu_platform.py +34 -50
- tpu_inference/runner/compilation_manager.py +144 -60
- tpu_inference/runner/kv_cache.py +40 -20
- tpu_inference/runner/kv_cache_manager.py +48 -33
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +280 -149
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -21
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +46 -18
- tpu_inference/worker/tpu_worker.py +197 -63
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
@@ -267,6 +267,7 @@ def _ragged_paged_attention_kernel(
     *,
     sm_scale: float,
     sliding_window: int | None = None,
+    strict_sliding_window: bool = True,
     soft_cap: float | None = None,
     mask_value: float = DEFAULT_MASK_VALUE,
     q_scale: float | None = None,
@@ -317,19 +318,20 @@ def _ragged_paged_attention_kernel(
     q_len = q_end - q_start
     kv_len = kv_lens_ref[seq_idx]

-    bkv_idx_start = 0 if sliding_window is None else jnp.maximum(
-        kv_len - sliding_window, 0) // bkv_sz
-
     if sliding_window is None:
-
+        bkv_idx_start = next_seq_bkv_idx_start = 0
     else:
+        bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window,
+                                    0) // bkv_sz

         def get_next_bkv_idx_start():
             next_kv_len = kv_lens_ref[seq_idx + 1]
-
+            next_q_len = cu_q_lens_ref[seq_idx + 2] - q_end
+            return jnp.maximum(next_kv_len - next_q_len - sliding_window,
+                               0) // bkv_sz

-
-
+        next_seq_bkv_idx_start = lax.cond(seq_idx + 1 < num_seqs,
+                                          get_next_bkv_idx_start, lambda: 0)

     def debug_print(msg, *args):
         if debug_mode:
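
The hunk above changes where a sliding-window sequence starts reading the KV cache: the first KV block index is now derived from kv_len - q_len - sliding_window, and the same quantity is computed for the following sequence (via lax.cond), apparently so the next sequence's starting block is known ahead of time. A rough worked example of the new start-block arithmetic, with toy numbers that are not from the source:

    # Toy numbers, not from the source: with a sliding window, whole KV blocks
    # that lie entirely in front of the window are skipped.
    kv_len, q_len, sliding_window, bkv_sz = 1000, 8, 64, 128
    bkv_idx_start = max(kv_len - q_len - sliding_window, 0) // bkv_sz
    print(bkv_idx_start)  # (1000 - 8 - 64) // 128 = 928 // 128 = 7, so blocks 0..6 are skipped
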
@@ -350,7 +352,7 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] q_len={}", q_len)
     debug_print("[RPA debug] kv_len={}", kv_len)

-    def
+    def flash_attention_step1_qk_softmax(
         q,  # [actual_bq_sz * num_q_heads_per_kv_head, actual_head_dim_x2]
         kv,  # [bkv_sz, actual_head_dim_x2]
         *,
@@ -364,7 +366,6 @@ def _ragged_paged_attention_kernel(
         assert kv.shape == (bkv_sz, actual_head_dim_x2)
         head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
         head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
-        head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]

         def load_with_init(ref, init_val):
             return jnp.where(bkv_idx == bkv_idx_start,
@@ -386,16 +387,19 @@ def _ragged_paged_attention_kernel(
             s *= k_scale
         if q_scale is not None:
             s *= q_scale
+        if soft_cap is not None:
+            s = soft_cap * jnp.tanh(s / soft_cap)

         q_span = (kv_len - q_len + bq_idx * bq_sz +
                   lax.broadcasted_iota(jnp.int32, s.shape, 0) //
                   num_q_heads_per_kv_head)
         k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
-        mask =
+        mask = k_span <= q_span

-        if
-
-
+        if sliding_window is not None and strict_sliding_window:
+            mask = jnp.logical_and(mask, q_span - sliding_window < k_span)
+
+        s = jnp.where(mask, s, mask_value)
         s_rowmax = jnp.max(s, axis=1, keepdims=True)

         if attention_sink_ref is not None:
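
The hunk above makes the causal mask explicit and, when `strict_sliding_window` is enabled, intersects it with a band that drops keys older than the window. The following standalone sketch, with toy sizes rather than the kernel's real block shapes, shows the two masks side by side; it runs outside the kernel and is only an illustration:

    import jax.numpy as jnp
    from jax import lax

    kv_len = q_len = 8                # toy sizes, not the kernel's block shapes
    bq_idx = bkv_idx = 0
    bq_sz = bkv_sz = 8
    num_q_heads_per_kv_head = 1
    sliding_window = 3

    shape = (bq_sz * num_q_heads_per_kv_head, bkv_sz)
    q_span = (kv_len - q_len + bq_idx * bq_sz +
              lax.broadcasted_iota(jnp.int32, shape, 0) // num_q_heads_per_kv_head)
    k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, shape, 1)

    mask = k_span <= q_span                                           # causal lower triangle
    strict = jnp.logical_and(mask, q_span - sliding_window < k_span)  # band of width 3
    print(mask.astype(jnp.int32))
    print(strict.astype(jnp.int32))
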
@@ -411,15 +415,33 @@ def _ragged_paged_attention_kernel(
         head_m_ref[...] = m_curr
         p = jnp.exp(s - broadcast_minor(m_curr, s.shape))

-        pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
-        if v_scale is not None:
-            pv *= v_scale
-
         p_rowsum = jnp.sum(p, axis=1, keepdims=True)
         exp_m_diff = jnp.exp(m_prev - m_curr)
         l_prev = load_with_init(head_l_ref, 1.0)
         l_curr = exp_m_diff * l_prev + p_rowsum
         head_l_ref[...] = l_curr
+
+        return p, exp_m_diff
+
+    def flash_attention_step2_pv(
+        q_shape_0,
+        kv,  # [bkv_sz, actual_head_dim_x2]
+        p,  # from step1
+        exp_m_diff,  # from step1
+        *,
+        bkv_idx,
+        kv_head_idx,
+    ):
+        head_acc_ref = acc_ref.at[kv_head_idx, :q_shape_0]
+
+        def load_with_init(ref, init_val):
+            return jnp.where(bkv_idx == bkv_idx_start,
+                             jnp.full_like(ref, init_val), ref[...])
+
+        pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
+        if v_scale is not None:
+            pv *= v_scale
+
         o_prev = load_with_init(head_acc_ref, 0.0)
         o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
         head_acc_ref[...] = o_curr
@@ -475,42 +497,51 @@ def _ragged_paged_attention_kernel(
         debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
         debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

-
-
-
-
-
-
-
-
-
-
+        if not wait:
+            # Fetch effective kv from kv cache.
+            def loop_body(i, offset):
+                sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
+                _async_copy(
+                    cache_hbm_ref.at[pl.ds(
+                        page_indices_ref[page_indices_offset + i] * page_size,
+                        sz)],
+                    vmem_ref.at[pl.ds(i * page_size, sz)],
+                    sem,
+                    wait=False,
+                )
+                debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+                return offset + sz
+
+            offset = lax.fori_loop(
+                0,
+                bkv_p_frm_cache,
+                loop_body,
+                0,  # offset
+                unroll=False,
             )
-            debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-            return offset + sz
-
-        offset = lax.fori_loop(
-            0,
-            bkv_p_frm_cache,
-            loop_body,
-            0,  # offset
-            unroll=False,
-        )

-
-        @pl.when(bkv_sz_frm_new > 0)
-        def _fetch_bkv_from_new_kv():
+            size = lax.select(bkv_sz_frm_new > 0, bkv_sz_frm_new, 0)
             new_kv_len_start = q_end - kv_left_frm_new
             debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
             debug_print("[RPA debug] offset_in_bkv={}", offset)
             _async_copy(
-                kv_hbm_ref.at[pl.ds(new_kv_len_start,
-                vmem_ref.at[pl.ds(offset,
+                kv_hbm_ref.at[pl.ds(new_kv_len_start, size)],
+                vmem_ref.at[pl.ds(offset, size)],
                 sem,
                 wait,
             )

-
+            return kv_len_start + offset, bkv_sz_frm_new
+        else:
+            offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
+            dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
+            _async_copy(
+                src=dst,
+                dst=dst,
+                sem=sem,
+                wait=True,
+            )
+            return kv_len_start + offset, bkv_sz_frm_new

     def _update_kv_cache(seq_idx,
                          bkv_sem_idx,
@@ -546,30 +577,41 @@ def _ragged_paged_attention_kernel(
         debug_print("[RPA debug] p_ignore={}", p_ignore)
         debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

-
-
-
-
+        if not wait:
+
+            def loop_body(i, states):
+                update_sz, ignore = states
+                sz = jnp.minimum(page_size - ignore, update_sz)
+
+                _async_copy(
+                    vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
+                                      sz)],
+                    cache_hbm_ref.at[pl.ds(
+                        page_indices_ref[page_indices_offset + i] * page_size +
+                        ignore,
+                        sz,
+                    )],
+                    sem,
+                    wait=False,
+                )
+                debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+                return update_sz - sz, 0
+
+            lax.fori_loop(
+                0,
+                kv_p_end - kv_p_start,
+                loop_body,
+                (update_sz, ignore),  # total transfer size
+                unroll=False,
+            )
+        else:
+            dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
             _async_copy(
-
-
-
-
-                sz,
-                )],
-                sem,
-                wait,
+                src=dst,
+                dst=dst,
+                sem=sem,
+                wait=True,
             )
-            debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-            return update_sz - sz, 0
-
-        lax.fori_loop(
-            0,
-            kv_p_end - kv_p_start,
-            loop_body,
-            (update_sz, ignore),  # total transfer size
-            unroll=False,
-        )

     def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
         sem = sems.at[1, bq_sem_idx]
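
In both `_fetch_bkv` and `_update_kv_cache` above, the body is now split on the `wait` flag: one call issues the per-page copies (`wait=False`), and a later call with `wait=True` blocks on the same semaphore via a self-copy before the buffer is consumed. A minimal plain-Python sketch of that calling pattern, using a hypothetical helper rather than the kernel's real DMA API:

    def async_copy(src, dst, sem, *, wait):
        # stand-in for the kernel's _async_copy: issue on wait=False, block on wait=True
        if wait:
            print(f"wait on {sem} before touching {dst}")
        else:
            print(f"issue copy {src} -> {dst}, signalling {sem}")

    async_copy("kv_cache[page 7]", "vmem_buf[1]", "sem0", wait=False)  # prefetch next block
    print("... run flash attention on the current block ...")
    async_copy("kv_cache[page 7]", "vmem_buf[1]", "sem0", wait=True)   # before switching buffers
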
@@ -737,13 +779,17 @@ def _ragged_paged_attention_kernel(
         next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
         next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)

-
-
-
+        if sliding_window is None:
+            next_bkv_start_idx = 0
+        else:
+            next_bkv_start_idx = lax.select(
                 is_last_bq,
-
+                next_seq_bkv_idx_start,
                 bkv_idx_start,
-        )
+            )
+        next_bkv_idx = lax.select(is_last_bkv, next_bkv_start_idx,
+                                  next_bkv_idx)
+
         return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx

     def compute_with_bq(bq_idx, _):
@@ -803,6 +849,11 @@ def _ragged_paged_attention_kernel(
             return

         # Flash attention with cur bkv and bq
+        prev_bq_shape_0 = None
+        prev_kv_head_bkv = None
+        prev_kv_head_idx = None
+        prev_kv_head_p = None
+        prev_kv_head_exp_m_diff = None
         for kv_head_start in range(0, actual_num_kv_heads, kv_packing):
             bkv_lst = strided_load_bkv(
                 bkv_sem_idx,
@@ -812,20 +863,51 @@ def _ragged_paged_attention_kernel(
             )
             assert len(bkv_lst) == kv_packing
             for i in range(kv_packing):
-
-                if
+                cur_kv_head_idx = kv_head_start + i
+                if cur_kv_head_idx >= actual_num_kv_heads:
                     break
-
-
-
-
-
-
-
-
-
-
-
+                cur_kv_head_bq = load_bq(bq_sem_idx,
+                                         cur_kv_head_idx,
+                                         actual_bq_sz=actual_bq_sz)
+                cur_kv_head__bkv = bkv_lst[i]
+                # FlashAttention is divided into `flash_attention_step1_qk_softmax`
+                # and `flash_attention_step2_pv` to pipeline the computation.
+                # `step2_pv` for the previous KV head, which depends on the softmax
+                # output, is overlapped with `step1_qk_softmax` for the current KV
+                # head, reducing overall wait times.
+                cur_kv_head_p, cur_kv_head_exp_m_diff = (
+                    flash_attention_step1_qk_softmax(
+                        cur_kv_head_bq,
+                        cur_kv_head__bkv,
+                        bq_idx=bq_idx,
+                        bkv_idx=bkv_idx,
+                        kv_head_idx=cur_kv_head_idx,
+                    ))
+                if prev_bq_shape_0 is not None:
+                    flash_attention_step2_pv(
+                        prev_bq_shape_0,
+                        prev_kv_head_bkv,
+                        prev_kv_head_p,
+                        prev_kv_head_exp_m_diff,
+                        bkv_idx=bkv_idx,
+                        kv_head_idx=prev_kv_head_idx,
+                    )
+                prev_bq_shape_0 = cur_kv_head_bq.shape[0]
+                prev_kv_head_bkv = cur_kv_head__bkv
+                prev_kv_head_p = cur_kv_head_p
+                prev_kv_head_exp_m_diff = cur_kv_head_exp_m_diff
+                prev_kv_head_idx = cur_kv_head_idx
+
+        # Execute pv of last attention head.
+        assert prev_bq_shape_0 is not None
+        flash_attention_step2_pv(
+            prev_bq_shape_0,
+            prev_kv_head_bkv,
+            prev_kv_head_p,
+            prev_kv_head_exp_m_diff,
+            bkv_idx=bkv_idx,
+            kv_head_idx=prev_kv_head_idx,
+        )

         lax.fori_loop(bkv_idx_start,
                       num_bkv,
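
The inner loop above is the software-pipelining change that the new comments describe: `step1_qk_softmax` is issued for the current KV head, `step2_pv` is then applied to the previous head, and the last head is flushed after the loop. A minimal stand-in for that interleaving, in plain Python rather than the kernel's control flow:

    def pipeline(heads, step1, step2):
        prev = None
        for h in heads:
            cur = step1(h)              # QK^T + online softmax for head h
            if prev is not None:
                step2(*prev)            # PV accumulation for the previous head
            prev = (h, cur)
        if prev is not None:
            step2(*prev)                # flush the last head

    pipeline(range(4),
             step1=lambda h: f"p_{h}",
             step2=lambda h, p: print(f"accumulate {p} into the head-{h} output"))
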
@@ -1226,6 +1308,7 @@ def static_validate_inputs(
     static_argnames=(
         "sm_scale",
         "sliding_window",
+        "strict_sliding_window",
         "soft_cap",
         "mask_value",
         "q_scale",
@@ -1255,6 +1338,7 @@ def ragged_paged_attention_hd64(
     *,
     sm_scale: float = 1.0,
     sliding_window: int | None = None,
+    strict_sliding_window: bool = True,
     soft_cap: float | None = None,
     mask_value: float | None = DEFAULT_MASK_VALUE,
     q_scale: float | None = None,
@@ -1269,42 +1353,41 @@ def ragged_paged_attention_hd64(
     # Debug params.
     debug_mode: bool = False,
 ):
-    """A
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    """
+    """A variant of ragged paged attention for head_dim=64.
+
+    Args:
+      queries: concatenated all sequences' queries.
+      keys: concatenated all sequences' keys (quantized).
+      values: concatenated all sequences' values (quantized).
+      kv_cache: paged KV cache with TPU-friendly shape.
+      kv_lens: padded kv lengths. Only the first num_seqs values are valid.
+      page_indices: flattened page indices look-up table by (seq_id, page_id).
+      cu_q_lens: the cumulative sum of the effective query lengths. Similar to
+        kv_lens, only the first num_seqs+1 values are valid.
+      distribution: (i, j, k) represents that sequences[0:i] are decode-only,
+        sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
+        k is also the total number of sequences.
+      attention_sink: optional attention sink for each q head.
+      sm_scale: the softmax scale which will be applied to the Q@K^T.
+      sliding_window: the sliding window size for the attention.
+      strict_sliding_window: compute tokens that are strictly within the window.
+      soft_cap: the logit soft cap for the attention.
+      mask_value: mask value for causal mask.
+      q_scale: the scale for the query.
+      k_scale: the scale for the key cache.
+      v_scale: the scale for the value cache.
+      chunk_prefill_size: the chunk prefill size for the attention.
+      num_kv_pages_per_block: number of kv pages to be processed in one flash
+        attention block in the pallas kernel.
+      num_queries_per_block: number of kv pages to be processed in one flash
+        attention block in the pallas kernel.
+      vmem_limit_bytes: the vmem limit for the pallas kernel.
+      debug_mode: if true, RPA does not issue any DMAs or run flash attention but
+        print debug info. Need to compile with `--xla_tpu_enable_log_recorder`.
+
+    Returns:
+      The output of the attention.
+    """
     q, k, v = queries, keys, values
     static_validate_inputs(
         q,
@@ -1374,7 +1457,7 @@ def ragged_paged_attention_hd64(
         pl.BlockSpec(memory_space=pltpu.HBM),
         pl.BlockSpec(memory_space=pltpu.HBM),
         None if attention_sink is None else pl.BlockSpec(
-            memory_space=pltpu.VMEM)
+            memory_space=pltpu.VMEM),
     ]

     out_specs = [
@@ -1438,6 +1521,7 @@ def ragged_paged_attention_hd64(
             _ragged_paged_attention_kernel,
             sm_scale=sm_scale,
             sliding_window=sliding_window,
+            strict_sliding_window=strict_sliding_window,
             soft_cap=soft_cap,
             mask_value=mask_value,
             q_scale=q_scale,
@@ -295,7 +295,8 @@ def get_tuned_block_sizes(
         bkv_p, bq = TUNED_BLOCK_SIZES[device][page_size][dtypes][head_dims][
             max_model_len]
     except KeyError:
-
+        logger.warning_once(
+            'Couldn`t find tuned sizes for the RPA v3 kernel with %s', keys)

     return (min(pages_per_seq, bkv_p), min(max_num_tokens, bq))

@@ -308,7 +308,13 @@ def sharded_ragged_paged_attention(
     args = (q, k, v, kv_cache, kv_lens, page_indices, cu_q_lens, distribution)

     use_hd64 = q.shape[-1] == 64
-
+
+    func = ragged_paged_attention
+    if use_hd64:
+        func = functools.partial(ragged_paged_attention_hd64,
+                                 strict_sliding_window=True)
+    else:
+        func = ragged_paged_attention

     if attention_sink is not None:
         if not use_hd64:
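
For head_dim == 64 the dispatcher above now pre-binds strict_sliding_window=True onto the hd64 kernel with functools.partial. A trivial sketch of that selection with stand-in functions (not the real kernels):

    import functools

    def ragged_paged_attention(*args, **kwargs):  # stand-in
        return "generic kernel"

    def ragged_paged_attention_hd64(*args, strict_sliding_window=True, **kwargs):  # stand-in
        return f"hd64 kernel, strict_sliding_window={strict_sliding_window}"

    use_hd64 = True
    func = (functools.partial(ragged_paged_attention_hd64, strict_sliding_window=True)
            if use_hd64 else ragged_paged_attention)
    print(func())  # hd64 kernel, strict_sliding_window=True
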
@@ -1,6 +1,5 @@
 import json
 import math
-import os
 from dataclasses import asdict, dataclass
 from typing import TYPE_CHECKING, List, Optional

@@ -8,7 +7,7 @@ import jax.numpy as jnp
 import numpy as np
 from jax.sharding import Mesh

-from tpu_inference import utils
+from tpu_inference import envs, utils

 if TYPE_CHECKING:
     from vllm.v1.configs.vllm_config import VllmConfig
@@ -48,7 +47,7 @@ class ShardingAxisName2D:


 try:
-    _use_base_sharding =
+    _use_base_sharding = envs.NEW_MODEL_DESIGN
     if _use_base_sharding:
         ShardingAxisName = ShardingAxisNameBase
     else:
@@ -120,9 +119,13 @@ class ShardingConfigManager:
                 False)
             if enable_dp_attention:
                 # Replicate attention layer when num_kv_heads < TP
-                num_kv_heads = vllm_config.model_config.get_total_num_kv_heads(
+                num_kv_heads = 1 if vllm_config.model_config.use_mla else vllm_config.model_config.get_total_num_kv_heads(
+                )
+                cache_dtype = vllm_config.cache_config.cache_dtype
+                if cache_dtype == 'auto':
+                    cache_dtype = vllm_config.model_config.dtype
                 kv_dtype = utils.get_jax_dtype_from_str_dtype(
-
+                    cache_dtype) or jnp.bfloat16
                 packing = 4 // jnp.dtype(kv_dtype).itemsize
                 # When num_kv_heads * 2 / packing < TP, tensor parallelism would
                 # duplicate KV heads across devices, wasting kv cache memory.
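
The sharding hunk above resolves cache_dtype='auto' to the model dtype before computing how many KV-cache elements fit in 32 bits (the `packing` factor). A quick check of that packing arithmetic; the dtype choices here are illustrative:

    import jax.numpy as jnp

    for kv_dtype in (jnp.bfloat16, jnp.float8_e4m3fn, jnp.float32):
        packing = 4 // jnp.dtype(kv_dtype).itemsize
        print(jnp.dtype(kv_dtype).name, packing)
    # bfloat16 2, float8_e4m3fn 4, float32 1
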
@@ -166,9 +169,10 @@ class ShardingConfigManager:
                     f"LoRA is not supported with data parallelism "
                     f"(DP size: {total_dp_size}). Please disable LoRA or "
                     f"set data parallelism to 1.")
-
+        if sharding_strategy.attention_data_parallelism > 1:
+            if not envs.NEW_MODEL_DESIGN:
                 raise ValueError(
-                    "Must run DP with NEW_MODEL_DESIGN enabled. Please set the "
+                    "Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set the "
                     "NEW_MODEL_DESIGN=True.")

     @property