tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/kernels/mla_v1_test.py +129 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
- tests/lora/test_layers.py +4 -1
- tests/lora/test_lora_perf.py +53 -0
- tests/test_envs.py +110 -12
- tests/test_quantization.py +3 -0
- tests/test_utils.py +1 -2
- tpu_inference/distributed/tpu_connector.py +1 -1
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/ray_distributed_executor.py +5 -1
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
- tpu_inference/kernels/mla/v1/kernel.py +98 -120
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +82 -32
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +146 -85
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +11 -7
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +170 -208
- tpu_inference/layers/vllm/linear_common.py +43 -21
- tpu_inference/layers/vllm/quantization/common.py +11 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
- tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
- tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
- tpu_inference/models/common/model_loader.py +78 -22
- tpu_inference/models/jax/deepseek_v3.py +185 -64
- tpu_inference/models/jax/gpt_oss.py +3 -3
- tpu_inference/models/jax/llama_eagle3.py +4 -5
- tpu_inference/models/jax/qwen2_5_vl.py +161 -47
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
- tpu_inference/models/jax/utils/weight_utils.py +203 -155
- tpu_inference/models/vllm/vllm_model_wrapper.py +11 -5
- tpu_inference/platforms/tpu_platform.py +29 -48
- tpu_inference/runner/compilation_manager.py +112 -46
- tpu_inference/runner/kv_cache.py +40 -20
- tpu_inference/runner/kv_cache_manager.py +40 -31
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +94 -51
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -22
- tpu_inference/utils.py +41 -14
- tpu_inference/worker/tpu_worker.py +43 -45
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +8 -9
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +59 -58
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
@@ -319,7 +319,7 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] q_len={}", q_len)
     debug_print("[RPA debug] kv_len={}", kv_len)

-    def …
+    def flash_attention_step1_qk_softmax(
         q,  # [actual_bq_sz * num_q_heads_per_kv_head, head_dim]
         k,  # [bkv_sz, head_dim]
         v,  # [bkv_sz, head_dim]
@@ -335,7 +335,6 @@ def _ragged_paged_attention_kernel(
         assert k.dtype == v.dtype
         head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
         head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
-        head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]

         def load_with_init(ref, init_val):
             return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val),
@@ -376,15 +375,32 @@ def _ragged_paged_attention_kernel(
         head_m_ref[...] = m_curr
         p = jnp.exp(s - broadcast_minor(m_curr, s.shape))

-        pv = jnp.einsum("nm,md->nd", p, v, preferred_element_type=jnp.float32)
-        if v_scale is not None:
-            pv *= v_scale
-
         p_rowsum = jnp.sum(p, axis=1, keepdims=True)
         exp_m_diff = jnp.exp(m_prev - m_curr)
         l_prev = load_with_init(head_l_ref, 0.0)
         l_curr = exp_m_diff * l_prev + p_rowsum
         head_l_ref[...] = l_curr
+
+        return p, exp_m_diff
+
+    def flash_attention_step2_pv(
+        q_shape_0,
+        v,  # [bkv_sz, head_dim]
+        p,  # from step1
+        exp_m_diff,  # from step1
+        *,
+        bkv_idx,
+        kv_head_idx,
+    ):
+        head_acc_ref = acc_ref.at[kv_head_idx, :q_shape_0]
+
+        def load_with_init(ref, init_val):
+            return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val),
+                             ref[...])
+
+        pv = jnp.einsum("nm,md->nd", p, v, preferred_element_type=jnp.float32)
+        if v_scale is not None:
+            pv *= v_scale
         o_prev = load_with_init(head_acc_ref, 0.0)
         o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
         head_acc_ref[...] = o_curr
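The split above separates the two halves of the streaming-softmax update: step1 produces the unnormalized probabilities `p` and the rescale factor `exp_m_diff`, and step2 applies them to the running output accumulator. Below is a minimal JAX sketch of that idea with invented names (`flash_step1`, `flash_step2`); it is an illustration of the online-softmax recurrence, not the kernel's code.

    # Hypothetical sketch: step1 = QK^T + softmax statistics, step2 = PV accumulation.
    import jax.numpy as jnp

    def flash_step1(q, k, m_prev, l_prev):
        s = q @ k.T                                    # [nq, nk] scores
        m_curr = jnp.maximum(m_prev, s.max(axis=1, keepdims=True))
        p = jnp.exp(s - m_curr)                        # unnormalized probabilities
        exp_m_diff = jnp.exp(m_prev - m_curr)          # rescale factor for old state
        l_curr = exp_m_diff * l_prev + p.sum(axis=1, keepdims=True)
        return p, exp_m_diff, m_curr, l_curr

    def flash_step2(p, v, exp_m_diff, acc_prev):
        return exp_m_diff * acc_prev + p @ v           # rescale old acc, add new PV

    nq, nk, d = 4, 8, 64
    q, k, v = jnp.ones((nq, d)), jnp.ones((nk, d)), jnp.ones((nk, d))
    m = jnp.full((nq, 1), -jnp.inf)
    l = jnp.zeros((nq, 1))
    acc = jnp.zeros((nq, d))
    p, e, m, l = flash_step1(q, k, m, l)
    acc = flash_step2(p, v, e, acc)
    out = acc / l                                      # final normalization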
@@ -463,19 +479,16 @@ def _ragged_paged_attention_kernel(
             unroll=False,
         )

-        … (10 lines elided by the diff viewer)
-            sem,
-            wait,
-        )
+        size = lax.select(bkv_sz_frm_new > 0, bkv_sz_frm_new, 0)
+        new_kv_len_start = q_end - kv_left_frm_new
+        debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
+        debug_print("[RPA debug] offset_in_bkv={}", offset)
+        _async_copy(
+            kv_hbm_ref.at[pl.ds(new_kv_len_start, size)],
+            vmem_ref.at[pl.ds(offset, size)],
+            sem,
+            wait,
+        )

         return kv_len_start + offset, bkv_sz_frm_new
     else:
@@ -842,6 +855,11 @@ def _ragged_paged_attention_kernel(

             # Flash attention with cur bkv and bq
             # NOTE: kv_packing is divided by 2 because k and v are packed together.
+            prev_bq_shape_0 = None
+            prev_kv_head_bv = None
+            prev_kv_head_idx = None
+            prev_kv_head_p = None
+            prev_kv_head_exp_m_diff = None
             heads_per_load = max(1, kv_packing // 2)
             for kv_head_start in range(0, actual_num_kv_heads,
                                        heads_per_load):
@@ -853,21 +871,53 @@ def _ragged_paged_attention_kernel(
                 )
                 assert len(bkv_lst) == heads_per_load
                 for i in range(heads_per_load):
-                    …
-                    if …
+                    cur_kv_head_idx = kv_head_start + i
+                    if cur_kv_head_idx >= actual_num_kv_heads:
                         break
-                    …
-                    …
-                    …
+
+                    cur_kv_head_bq = load_bq(bq_sem_idx,
+                                             cur_kv_head_idx,
+                                             actual_bq_sz=actual_bq_sz)
                     bk, bv = bkv_lst[i]
-                    … (8 lines elided by the diff viewer)
+                    # FlashAttention is divided into `flash_attention_step1_qk_softmax`
+                    # and `flash_attention_step2_pv` to pipeline the computation.
+                    # `step2_pv` for the previous KV head, which depends on the softmax
+                    # output, is overlapped with `step1_qk_softmax` for the current KV
+                    # head, reducing overall wait times.
+                    cur_kv_head_p, cur_kv_head_exp_m_diff = (
+                        flash_attention_step1_qk_softmax(
+                            cur_kv_head_bq,
+                            bk,
+                            bv,
+                            bq_idx=bq_idx,
+                            bkv_idx=bkv_idx,
+                            kv_head_idx=cur_kv_head_idx,
+                        ))
+                    if prev_bq_shape_0 is not None:
+                        flash_attention_step2_pv(
+                            prev_bq_shape_0,
+                            prev_kv_head_bv,
+                            prev_kv_head_p,
+                            prev_kv_head_exp_m_diff,
+                            bkv_idx=bkv_idx,
+                            kv_head_idx=prev_kv_head_idx,
+                        )
+                    prev_bq_shape_0 = cur_kv_head_bq.shape[0]
+                    prev_kv_head_bv = bv
+                    prev_kv_head_p = cur_kv_head_p
+                    prev_kv_head_exp_m_diff = cur_kv_head_exp_m_diff
+                    prev_kv_head_idx = cur_kv_head_idx
+
+            # Execute pv of last attention head.
+            assert prev_bq_shape_0 is not None
+            flash_attention_step2_pv(
+                prev_bq_shape_0,
+                prev_kv_head_bv,
+                prev_kv_head_p,
+                prev_kv_head_exp_m_diff,
+                bkv_idx=bkv_idx,
+                kv_head_idx=prev_kv_head_idx,
+            )

         lax.fori_loop(0, num_bkv, compute_with_bkv, None, unroll=False)

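The loop body above uses a one-head software-pipelining pattern: step1 for the current KV head is issued first, step2 for the previous head is drained next, and the final head is flushed after the loop. A hedged, pure-Python sketch of that control flow (helper names here are invented, not the kernel's API):

    # Hypothetical sketch of the one-head-delayed pipeline used above.
    def pipelined_heads(heads, step1, step2):
        prev = None                            # pending (v, p, exp_m_diff, head_idx)
        for h, (q, k, v) in enumerate(heads):
            p, e = step1(q, k, head_idx=h)     # softmax work for head h
            if prev is not None:
                step2(*prev)                   # PV work for head h-1
            prev = (v, p, e, h)
        assert prev is not None
        step2(*prev)                           # drain the last head

    # Toy run that records the interleaving of the two stages.
    order = []
    def s1(q, k, head_idx):
        order.append(("step1", head_idx))
        return "p", "e"
    def s2(v, p, e, head_idx):
        order.append(("step2", head_idx))
    pipelined_heads([(None, None, None)] * 3, s1, s2)
    # order == [("step1", 0), ("step1", 1), ("step2", 0),
    #           ("step1", 2), ("step2", 1), ("step2", 2)]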
@@ -267,6 +267,7 @@ def _ragged_paged_attention_kernel(
     *,
     sm_scale: float,
     sliding_window: int | None = None,
+    strict_sliding_window: bool = True,
     soft_cap: float | None = None,
     mask_value: float = DEFAULT_MASK_VALUE,
     q_scale: float | None = None,
@@ -317,19 +318,20 @@ def _ragged_paged_attention_kernel(
     q_len = q_end - q_start
     kv_len = kv_lens_ref[seq_idx]

-    bkv_idx_start = 0 if sliding_window is None else jnp.maximum(
-        kv_len - sliding_window, 0) // bkv_sz
-
     if sliding_window is None:
-        …
+        bkv_idx_start = next_seq_bkv_idx_start = 0
     else:
+        bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window,
+                                    0) // bkv_sz

         def get_next_bkv_idx_start():
             next_kv_len = kv_lens_ref[seq_idx + 1]
-            …
+            next_q_len = cu_q_lens_ref[seq_idx + 2] - q_end
+            return jnp.maximum(next_kv_len - next_q_len - sliding_window,
+                               0) // bkv_sz

-        …
-        …
+        next_seq_bkv_idx_start = lax.cond(seq_idx + 1 < num_seqs,
+                                          get_next_bkv_idx_start, lambda: 0)

     def debug_print(msg, *args):
         if debug_mode:
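The revised start-block formula subtracts `q_len`, so the first KV block loaded also covers the sliding window of the earliest query row in the sequence rather than only the latest one. A small arithmetic check with made-up values:

    # Illustrative numbers only (not from the kernel or its tests).
    kv_len, q_len, sliding_window, bkv_sz = 1000, 200, 128, 64
    old_start = max(kv_len - sliding_window, 0) // bkv_sz          # 13 -> KV 832
    new_start = max(kv_len - q_len - sliding_window, 0) // bkv_sz  # 10 -> KV 640
    # The first query row sits at KV position kv_len - q_len = 800, so its window
    # reaches back to ~672; starting at block 13 (KV 832) would skip those keys.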
@@ -350,7 +352,7 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] q_len={}", q_len)
     debug_print("[RPA debug] kv_len={}", kv_len)

-    def …
+    def flash_attention_step1_qk_softmax(
         q,  # [actual_bq_sz * num_q_heads_per_kv_head, actual_head_dim_x2]
         kv,  # [bkv_sz, actual_head_dim_x2]
         *,
@@ -364,7 +366,6 @@ def _ragged_paged_attention_kernel(
         assert kv.shape == (bkv_sz, actual_head_dim_x2)
         head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
         head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
-        head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]

         def load_with_init(ref, init_val):
             return jnp.where(bkv_idx == bkv_idx_start,
@@ -386,16 +387,19 @@ def _ragged_paged_attention_kernel(
             s *= k_scale
         if q_scale is not None:
             s *= q_scale
+        if soft_cap is not None:
+            s = soft_cap * jnp.tanh(s / soft_cap)

         q_span = (kv_len - q_len + bq_idx * bq_sz +
                   lax.broadcasted_iota(jnp.int32, s.shape, 0) //
                   num_q_heads_per_kv_head)
         k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
-        mask = …
+        mask = k_span <= q_span

-        if …
-        …
-        …
+        if sliding_window is not None and strict_sliding_window:
+            mask = jnp.logical_and(mask, q_span - sliding_window < k_span)
+
+        s = jnp.where(mask, s, mask_value)
         s_rowmax = jnp.max(s, axis=1, keepdims=True)

         if attention_sink_ref is not None:
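For reference, a toy example (not the kernel's code; values invented) of the two steps added above: logit soft-capping via tanh, and the strict sliding-window mask that excludes keys older than `q - sliding_window` on top of the causal mask.

    import jax.numpy as jnp

    soft_cap, sliding_window, mask_value = 30.0, 4, -1e30
    s = jnp.arange(12.0).reshape(3, 4)            # toy [q, k] score block
    q_span = jnp.array([[5], [6], [7]])           # absolute query positions
    k_span = jnp.arange(4)[None, :] + 4           # absolute key positions 4..7

    s = soft_cap * jnp.tanh(s / soft_cap)         # logit soft cap
    mask = k_span <= q_span                       # causal
    mask = jnp.logical_and(mask, q_span - sliding_window < k_span)  # strict window
    s = jnp.where(mask, s, mask_value)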
@@ -411,15 +415,33 @@ def _ragged_paged_attention_kernel(
         head_m_ref[...] = m_curr
         p = jnp.exp(s - broadcast_minor(m_curr, s.shape))

-        pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
-        if v_scale is not None:
-            pv *= v_scale
-
         p_rowsum = jnp.sum(p, axis=1, keepdims=True)
         exp_m_diff = jnp.exp(m_prev - m_curr)
         l_prev = load_with_init(head_l_ref, 1.0)
         l_curr = exp_m_diff * l_prev + p_rowsum
         head_l_ref[...] = l_curr
+
+        return p, exp_m_diff
+
+    def flash_attention_step2_pv(
+        q_shape_0,
+        kv,  # [bkv_sz, actual_head_dim_x2]
+        p,  # from step1
+        exp_m_diff,  # from step1
+        *,
+        bkv_idx,
+        kv_head_idx,
+    ):
+        head_acc_ref = acc_ref.at[kv_head_idx, :q_shape_0]
+
+        def load_with_init(ref, init_val):
+            return jnp.where(bkv_idx == bkv_idx_start,
+                             jnp.full_like(ref, init_val), ref[...])
+
+        pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
+        if v_scale is not None:
+            pv *= v_scale
+
         o_prev = load_with_init(head_acc_ref, 0.0)
         o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
         head_acc_ref[...] = o_curr
@@ -498,19 +520,16 @@ def _ragged_paged_attention_kernel(
             unroll=False,
         )

-        … (10 lines elided by the diff viewer)
-            sem,
-            wait,
-        )
+        size = lax.select(bkv_sz_frm_new > 0, bkv_sz_frm_new, 0)
+        new_kv_len_start = q_end - kv_left_frm_new
+        debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
+        debug_print("[RPA debug] offset_in_bkv={}", offset)
+        _async_copy(
+            kv_hbm_ref.at[pl.ds(new_kv_len_start, size)],
+            vmem_ref.at[pl.ds(offset, size)],
+            sem,
+            wait,
+        )

         return kv_len_start + offset, bkv_sz_frm_new
     else:
@@ -760,13 +779,17 @@ def _ragged_paged_attention_kernel(
         next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
         next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)

-        …
-        …
-        …
+        if sliding_window is None:
+            next_bkv_start_idx = 0
+        else:
+            next_bkv_start_idx = lax.select(
                 is_last_bq,
-                …
+                next_seq_bkv_idx_start,
                 bkv_idx_start,
-        )
+            )
+        next_bkv_idx = lax.select(is_last_bkv, next_bkv_start_idx,
+                                  next_bkv_idx)
+
         return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx

     def compute_with_bq(bq_idx, _):
@@ -826,6 +849,11 @@ def _ragged_paged_attention_kernel(
                 return

             # Flash attention with cur bkv and bq
+            prev_bq_shape_0 = None
+            prev_kv_head_bkv = None
+            prev_kv_head_idx = None
+            prev_kv_head_p = None
+            prev_kv_head_exp_m_diff = None
             for kv_head_start in range(0, actual_num_kv_heads, kv_packing):
                 bkv_lst = strided_load_bkv(
                     bkv_sem_idx,
@@ -835,20 +863,51 @@ def _ragged_paged_attention_kernel(
                 )
                 assert len(bkv_lst) == kv_packing
                 for i in range(kv_packing):
-                    …
-                    if …
+                    cur_kv_head_idx = kv_head_start + i
+                    if cur_kv_head_idx >= actual_num_kv_heads:
                         break
-                    … (11 lines elided by the diff viewer)
+                    cur_kv_head_bq = load_bq(bq_sem_idx,
+                                             cur_kv_head_idx,
+                                             actual_bq_sz=actual_bq_sz)
+                    cur_kv_head__bkv = bkv_lst[i]
+                    # FlashAttention is divided into `flash_attention_step1_qk_softmax`
+                    # and `flash_attention_step2_pv` to pipeline the computation.
+                    # `step2_pv` for the previous KV head, which depends on the softmax
+                    # output, is overlapped with `step1_qk_softmax` for the current KV
+                    # head, reducing overall wait times.
+                    cur_kv_head_p, cur_kv_head_exp_m_diff = (
+                        flash_attention_step1_qk_softmax(
+                            cur_kv_head_bq,
+                            cur_kv_head__bkv,
+                            bq_idx=bq_idx,
+                            bkv_idx=bkv_idx,
+                            kv_head_idx=cur_kv_head_idx,
+                        ))
+                    if prev_bq_shape_0 is not None:
+                        flash_attention_step2_pv(
+                            prev_bq_shape_0,
+                            prev_kv_head_bkv,
+                            prev_kv_head_p,
+                            prev_kv_head_exp_m_diff,
+                            bkv_idx=bkv_idx,
+                            kv_head_idx=prev_kv_head_idx,
+                        )
+                    prev_bq_shape_0 = cur_kv_head_bq.shape[0]
+                    prev_kv_head_bkv = cur_kv_head__bkv
+                    prev_kv_head_p = cur_kv_head_p
+                    prev_kv_head_exp_m_diff = cur_kv_head_exp_m_diff
+                    prev_kv_head_idx = cur_kv_head_idx
+
+            # Execute pv of last attention head.
+            assert prev_bq_shape_0 is not None
+            flash_attention_step2_pv(
+                prev_bq_shape_0,
+                prev_kv_head_bkv,
+                prev_kv_head_p,
+                prev_kv_head_exp_m_diff,
+                bkv_idx=bkv_idx,
+                kv_head_idx=prev_kv_head_idx,
+            )

         lax.fori_loop(bkv_idx_start,
                       num_bkv,
@@ -1249,6 +1308,7 @@ def static_validate_inputs(
     static_argnames=(
         "sm_scale",
         "sliding_window",
+        "strict_sliding_window",
         "soft_cap",
         "mask_value",
         "q_scale",
@@ -1278,6 +1338,7 @@ def ragged_paged_attention_hd64(
     *,
     sm_scale: float = 1.0,
     sliding_window: int | None = None,
+    strict_sliding_window: bool = True,
     soft_cap: float | None = None,
     mask_value: float | None = DEFAULT_MASK_VALUE,
     q_scale: float | None = None,
@@ -1292,42 +1353,41 @@ def ragged_paged_attention_hd64(
     # Debug params.
     debug_mode: bool = False,
 ):
-    """A …
-    … (34 lines of the old docstring elided by the diff viewer)
-    """
+    """A variant of ragged paged attention for head_dim=64.
+
+    Args:
+      queries: concatenated all sequences' queries.
+      keys: concatenated all sequences' keys (quantized).
+      values: concatenated all sequences' values (quantized).
+      kv_cache: paged KV cache with TPU-friendly shape.
+      kv_lens: padded kv lengths. Only the first num_seqs values are valid.
+      page_indices: flattened page indices look-up table by (seq_id, page_id).
+      cu_q_lens: the cumulative sum of the effective query lengths. Similar to
+        kv_lens, only the first num_seqs+1 values are valid.
+      distribution: (i, j, k) represents that sequences[0:i] are decode-only,
+        sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
+        k is also the total number of sequences.
+      attention_sink: optional attention sink for each q head.
+      sm_scale: the softmax scale which will be applied to the Q@K^T.
+      sliding_window: the sliding window size for the attention.
+      strict_sliding_window: compute tokens that are strictly within the window.
+      soft_cap: the logit soft cap for the attention.
+      mask_value: mask value for causal mask.
+      q_scale: the scale for the query.
+      k_scale: the scale for the key cache.
+      v_scale: the scale for the value cache.
+      chunk_prefill_size: the chunk prefill size for the attention.
+      num_kv_pages_per_block: number of kv pages to be processed in one flash
+        attention block in the pallas kernel.
+      num_queries_per_block: number of kv pages to be processed in one flash
+        attention block in the pallas kernel.
+      vmem_limit_bytes: the vmem limit for the pallas kernel.
+      debug_mode: if true, RPA does not issue any DMAs or run flash attention but
+        print debug info. Need to compile with `--xla_tpu_enable_log_recorder`.
+
+    Returns:
+      The output of the attention.
+    """
     q, k, v = queries, keys, values
     static_validate_inputs(
         q,
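A small illustration of the `distribution=(i, j, k)` convention documented above (the concrete numbers are invented):

    i, j, k = 2, 5, 6                    # e.g. distribution = (2, 5, 6)
    decode_only = range(0, i)            # sequences 0, 1
    chunked_prefill_only = range(i, j)   # sequences 2, 3, 4
    mixed = range(j, k)                  # sequence 5; k is also the total count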
@@ -1397,7 +1457,7 @@ def ragged_paged_attention_hd64(
         pl.BlockSpec(memory_space=pltpu.HBM),
         pl.BlockSpec(memory_space=pltpu.HBM),
         None if attention_sink is None else pl.BlockSpec(
-            memory_space=pltpu.VMEM)
+            memory_space=pltpu.VMEM),
     ]

     out_specs = [
|
@@ -1461,6 +1521,7 @@ def ragged_paged_attention_hd64(
|
|
|
1461
1521
|
_ragged_paged_attention_kernel,
|
|
1462
1522
|
sm_scale=sm_scale,
|
|
1463
1523
|
sliding_window=sliding_window,
|
|
1524
|
+
strict_sliding_window=strict_sliding_window,
|
|
1464
1525
|
soft_cap=soft_cap,
|
|
1465
1526
|
mask_value=mask_value,
|
|
1466
1527
|
q_scale=q_scale,
|
|
@@ -295,7 +295,8 @@ def get_tuned_block_sizes(
         bkv_p, bq = TUNED_BLOCK_SIZES[device][page_size][dtypes][head_dims][
             max_model_len]
     except KeyError:
-        …
+        logger.warning_once(
+            'Couldn`t find tuned sizes for the RPA v3 kernel with %s', keys)

     return (min(pages_per_seq, bkv_p), min(max_num_tokens, bq))

@@ -308,7 +308,13 @@ def sharded_ragged_paged_attention(
     args = (q, k, v, kv_cache, kv_lens, page_indices, cu_q_lens, distribution)

     use_hd64 = q.shape[-1] == 64
-    …
+
+    func = ragged_paged_attention
+    if use_hd64:
+        func = functools.partial(ragged_paged_attention_hd64,
+                                 strict_sliding_window=True)
+    else:
+        func = ragged_paged_attention

     if attention_sink is not None:
         if not use_hd64:
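The dispatch above pre-binds the hd64-only keyword with `functools.partial` so both kernels are invoked through the same `func(*args, ...)` path. A minimal sketch of the pattern, with placeholder kernels standing in for the real ones:

    import functools

    def rpa(*args, **kwargs): ...     # placeholder for ragged_paged_attention
    def rpa_hd64(*args, strict_sliding_window=True, **kwargs): ...

    def pick_kernel(head_dim):
        if head_dim == 64:
            return functools.partial(rpa_hd64, strict_sliding_window=True)
        return rpa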
@@ -1,6 +1,5 @@
 import json
 import math
-import os
 from dataclasses import asdict, dataclass
 from typing import TYPE_CHECKING, List, Optional

@@ -8,7 +7,7 @@ import jax.numpy as jnp
 import numpy as np
 from jax.sharding import Mesh

-from tpu_inference import utils
+from tpu_inference import envs, utils

 if TYPE_CHECKING:
     from vllm.v1.configs.vllm_config import VllmConfig
@@ -48,7 +47,7 @@ class ShardingAxisName2D:


 try:
-    _use_base_sharding = …
+    _use_base_sharding = envs.NEW_MODEL_DESIGN
     if _use_base_sharding:
         ShardingAxisName = ShardingAxisNameBase
     else:
@@ -120,9 +119,13 @@ class ShardingConfigManager:
                 False)
             if enable_dp_attention:
                 # Replicate attention layer when num_kv_heads < TP
-                num_kv_heads = vllm_config.model_config.get_total_num_kv_heads(
+                num_kv_heads = 1 if vllm_config.model_config.use_mla else vllm_config.model_config.get_total_num_kv_heads(
+                )
+                cache_dtype = vllm_config.cache_config.cache_dtype
+                if cache_dtype == 'auto':
+                    cache_dtype = vllm_config.model_config.dtype
                 kv_dtype = utils.get_jax_dtype_from_str_dtype(
-                    …
+                    cache_dtype) or jnp.bfloat16
                 packing = 4 // jnp.dtype(kv_dtype).itemsize
                 # When num_kv_heads * 2 / packing < TP, tensor parallelism would
                 # duplicate KV heads across devices, wasting kv cache memory.
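For the packing computation introduced above, `packing` is the number of KV elements per 32-bit word, so the replication check depends on the cache dtype. A worked example with illustrative values:

    import jax.numpy as jnp

    for dtype in (jnp.bfloat16, jnp.float8_e4m3fn):
        packing = 4 // jnp.dtype(dtype).itemsize        # 2 for bf16, 4 for fp8
        print(jnp.dtype(dtype).name, packing)

    num_kv_heads, tp = 1, 8                             # e.g. an MLA model (illustrative)
    packing = 4 // jnp.dtype(jnp.bfloat16).itemsize     # 2
    prefers_dp = (num_kv_heads * 2 / packing) < tp      # True -> TP would duplicate KV heads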
@@ -166,9 +169,10 @@ class ShardingConfigManager:
                     f"LoRA is not supported with data parallelism "
                     f"(DP size: {total_dp_size}). Please disable LoRA or "
                     f"set data parallelism to 1.")
-            …
+            if sharding_strategy.attention_data_parallelism > 1:
+                if not envs.NEW_MODEL_DESIGN:
                     raise ValueError(
-                        "Must run DP with NEW_MODEL_DESIGN enabled. Please set the "
+                        "Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set the "
                         "NEW_MODEL_DESIGN=True.")

     @property