tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
|
@@ -1,3 +1,16 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
1
14
|
"""
|
|
2
15
|
A variant of TPU-Friendly Ragged Paged Attention kernel optimized for
|
|
3
16
|
head_dim = 64.
|
|
@@ -317,19 +330,21 @@ def _ragged_paged_attention_kernel(
|
|
|
317
330
|
q_len = q_end - q_start
|
|
318
331
|
kv_len = kv_lens_ref[seq_idx]
|
|
319
332
|
|
|
320
|
-
bkv_idx_start = 0 if sliding_window is None else jnp.maximum(
|
|
321
|
-
kv_len - sliding_window, 0) // bkv_sz
|
|
322
|
-
|
|
323
333
|
if sliding_window is None:
|
|
324
|
-
|
|
334
|
+
bkv_idx_start = next_seq_bkv_idx_start = 0
|
|
325
335
|
else:
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
336
|
+
bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window,
|
|
337
|
+
0) // bkv_sz
|
|
338
|
+
|
|
339
|
+
# If seq_idx + 1 == num_seqs, kv_lens_ref[seq_idx + 1] will trigger an
|
|
340
|
+
# out-of-bound error. To avoid this, we set upperbound of next_seq_idx
|
|
341
|
+
# to be num_seqs - 1.
|
|
342
|
+
next_seq_idx = jnp.minimum(seq_idx + 1, num_seqs - 1)
|
|
343
|
+
next_kv_len = kv_lens_ref[next_seq_idx]
|
|
344
|
+
next_q_len = cu_q_lens_ref[next_seq_idx + 1] - q_end
|
|
345
|
+
next_seq_bkv_idx_start = (
|
|
346
|
+
jnp.maximum(next_kv_len - next_q_len - sliding_window, 0) //
|
|
347
|
+
bkv_sz)
|
|
333
348
|
|
|
334
349
|
def debug_print(msg, *args):
|
|
335
350
|
if debug_mode:
|
|
@@ -350,7 +365,7 @@ def _ragged_paged_attention_kernel(
|
|
|
350
365
|
debug_print("[RPA debug] q_len={}", q_len)
|
|
351
366
|
debug_print("[RPA debug] kv_len={}", kv_len)
|
|
352
367
|
|
|
353
|
-
def
|
|
368
|
+
def flash_attention_step1_qk_softmax(
|
|
354
369
|
q, # [actual_bq_sz * num_q_heads_per_kv_head, actual_head_dim_x2]
|
|
355
370
|
kv, # [bkv_sz, actual_head_dim_x2]
|
|
356
371
|
*,
|
|
@@ -364,7 +379,6 @@ def _ragged_paged_attention_kernel(
|
|
|
364
379
|
assert kv.shape == (bkv_sz, actual_head_dim_x2)
|
|
365
380
|
head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
|
|
366
381
|
head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
|
|
367
|
-
head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]
|
|
368
382
|
|
|
369
383
|
def load_with_init(ref, init_val):
|
|
370
384
|
return jnp.where(bkv_idx == bkv_idx_start,
|
|
@@ -386,16 +400,19 @@ def _ragged_paged_attention_kernel(
|
|
|
386
400
|
s *= k_scale
|
|
387
401
|
if q_scale is not None:
|
|
388
402
|
s *= q_scale
|
|
403
|
+
if soft_cap is not None:
|
|
404
|
+
s = soft_cap * jnp.tanh(s / soft_cap)
|
|
389
405
|
|
|
390
406
|
q_span = (kv_len - q_len + bq_idx * bq_sz +
|
|
391
407
|
lax.broadcasted_iota(jnp.int32, s.shape, 0) //
|
|
392
408
|
num_q_heads_per_kv_head)
|
|
393
409
|
k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
|
|
394
|
-
mask =
|
|
410
|
+
mask = k_span <= q_span
|
|
395
411
|
|
|
396
|
-
if
|
|
397
|
-
|
|
398
|
-
|
|
412
|
+
if sliding_window is not None:
|
|
413
|
+
mask = jnp.logical_and(mask, q_span - sliding_window < k_span)
|
|
414
|
+
|
|
415
|
+
s = jnp.where(mask, s, mask_value)
|
|
399
416
|
s_rowmax = jnp.max(s, axis=1, keepdims=True)
|
|
400
417
|
|
|
401
418
|
if attention_sink_ref is not None:
|
|
@@ -411,15 +428,33 @@ def _ragged_paged_attention_kernel(
|
|
|
411
428
|
head_m_ref[...] = m_curr
|
|
412
429
|
p = jnp.exp(s - broadcast_minor(m_curr, s.shape))
|
|
413
430
|
|
|
414
|
-
pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
|
|
415
|
-
if v_scale is not None:
|
|
416
|
-
pv *= v_scale
|
|
417
|
-
|
|
418
431
|
p_rowsum = jnp.sum(p, axis=1, keepdims=True)
|
|
419
432
|
exp_m_diff = jnp.exp(m_prev - m_curr)
|
|
420
433
|
l_prev = load_with_init(head_l_ref, 1.0)
|
|
421
434
|
l_curr = exp_m_diff * l_prev + p_rowsum
|
|
422
435
|
head_l_ref[...] = l_curr
|
|
436
|
+
|
|
437
|
+
return p, exp_m_diff
|
|
438
|
+
|
|
439
|
+
def flash_attention_step2_pv(
|
|
440
|
+
q_shape_0,
|
|
441
|
+
kv, # [bkv_sz, actual_head_dim_x2]
|
|
442
|
+
p, # from step1
|
|
443
|
+
exp_m_diff, # from step1
|
|
444
|
+
*,
|
|
445
|
+
bkv_idx,
|
|
446
|
+
kv_head_idx,
|
|
447
|
+
):
|
|
448
|
+
head_acc_ref = acc_ref.at[kv_head_idx, :q_shape_0]
|
|
449
|
+
|
|
450
|
+
def load_with_init(ref, init_val):
|
|
451
|
+
return jnp.where(bkv_idx == bkv_idx_start,
|
|
452
|
+
jnp.full_like(ref, init_val), ref[...])
|
|
453
|
+
|
|
454
|
+
pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
|
|
455
|
+
if v_scale is not None:
|
|
456
|
+
pv *= v_scale
|
|
457
|
+
|
|
423
458
|
o_prev = load_with_init(head_acc_ref, 0.0)
|
|
424
459
|
o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
|
|
425
460
|
head_acc_ref[...] = o_curr
|
|
@@ -498,19 +533,16 @@ def _ragged_paged_attention_kernel(
|
|
|
498
533
|
unroll=False,
|
|
499
534
|
)
|
|
500
535
|
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
sem,
|
|
512
|
-
wait,
|
|
513
|
-
)
|
|
536
|
+
size = lax.select(bkv_sz_frm_new > 0, bkv_sz_frm_new, 0)
|
|
537
|
+
new_kv_len_start = q_end - kv_left_frm_new
|
|
538
|
+
debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
|
|
539
|
+
debug_print("[RPA debug] offset_in_bkv={}", offset)
|
|
540
|
+
_async_copy(
|
|
541
|
+
kv_hbm_ref.at[pl.ds(new_kv_len_start, size)],
|
|
542
|
+
vmem_ref.at[pl.ds(offset, size)],
|
|
543
|
+
sem,
|
|
544
|
+
wait,
|
|
545
|
+
)
|
|
514
546
|
|
|
515
547
|
return kv_len_start + offset, bkv_sz_frm_new
|
|
516
548
|
else:
|
|
@@ -704,7 +736,7 @@ def _ragged_paged_attention_kernel(
|
|
|
704
736
|
vec = ref[start::step]
|
|
705
737
|
return vec
|
|
706
738
|
|
|
707
|
-
def strided_load_bkv(bkv_sem_idx, start, step
|
|
739
|
+
def strided_load_bkv(bkv_sem_idx, start, step):
|
|
708
740
|
assert start % kv_packing == 0
|
|
709
741
|
assert step % kv_packing == 0
|
|
710
742
|
start //= kv_packing
|
|
@@ -713,7 +745,6 @@ def _ragged_paged_attention_kernel(
|
|
|
713
745
|
bkv_sz * step, actual_head_dim_x2))
|
|
714
746
|
|
|
715
747
|
kv = strided_load(kv_ref, start, step)
|
|
716
|
-
kv = lax.select(bkv_mask, kv, jnp.zeros_like(kv))
|
|
717
748
|
bitwidth = 32 // kv_packing
|
|
718
749
|
repack_ty = jnp.dtype(f"uint{bitwidth}")
|
|
719
750
|
lst = []
|
|
@@ -760,13 +791,21 @@ def _ragged_paged_attention_kernel(
|
|
|
760
791
|
next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
|
|
761
792
|
next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)
|
|
762
793
|
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
794
|
+
if sliding_window is None:
|
|
795
|
+
# When sliding window is disabled, starting bkv_idx of next request is
|
|
796
|
+
# always 0 regardless of seq_idx of next request.
|
|
797
|
+
next_bkv_idx_start = 0
|
|
798
|
+
else:
|
|
799
|
+
# Determine starting bkv_idx of next request based on whether next
|
|
800
|
+
# request is from the same sequence or next sequence.
|
|
801
|
+
next_bkv_idx_start = lax.select(
|
|
766
802
|
is_last_bq,
|
|
767
|
-
|
|
803
|
+
next_seq_bkv_idx_start,
|
|
768
804
|
bkv_idx_start,
|
|
769
|
-
)
|
|
805
|
+
)
|
|
806
|
+
next_bkv_idx = lax.select(is_last_bkv, next_bkv_idx_start,
|
|
807
|
+
next_bkv_idx)
|
|
808
|
+
|
|
770
809
|
return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx
|
|
771
810
|
|
|
772
811
|
def compute_with_bq(bq_idx, _):
|
|
@@ -783,10 +822,6 @@ def _ragged_paged_attention_kernel(
|
|
|
783
822
|
def compute_with_bkv(bkv_idx, _):
|
|
784
823
|
# Create bitmask for KV.
|
|
785
824
|
assert bkv_sz % kv_packing == 0
|
|
786
|
-
actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
|
|
787
|
-
bkv_shape = (bkv_sz, actual_head_dim_x2)
|
|
788
|
-
bkv_mask = lax.broadcasted_iota(jnp.int32, bkv_shape,
|
|
789
|
-
0) < actual_bkv_sz
|
|
790
825
|
|
|
791
826
|
# Get next bkv ids.
|
|
792
827
|
bkv_sem_idx = sem_ids_ref[1]
|
|
@@ -826,29 +861,64 @@ def _ragged_paged_attention_kernel(
|
|
|
826
861
|
return
|
|
827
862
|
|
|
828
863
|
# Flash attention with cur bkv and bq
|
|
864
|
+
prev_bq_shape_0 = None
|
|
865
|
+
prev_kv_head_bkv = None
|
|
866
|
+
prev_kv_head_idx = None
|
|
867
|
+
prev_kv_head_p = None
|
|
868
|
+
prev_kv_head_exp_m_diff = None
|
|
829
869
|
for kv_head_start in range(0, actual_num_kv_heads, kv_packing):
|
|
830
870
|
bkv_lst = strided_load_bkv(
|
|
831
871
|
bkv_sem_idx,
|
|
832
872
|
kv_head_start,
|
|
833
873
|
num_kv_heads,
|
|
834
|
-
bkv_mask=bkv_mask,
|
|
835
874
|
)
|
|
836
875
|
assert len(bkv_lst) == kv_packing
|
|
837
876
|
for i in range(kv_packing):
|
|
838
|
-
|
|
839
|
-
if
|
|
877
|
+
cur_kv_head_idx = kv_head_start + i
|
|
878
|
+
if cur_kv_head_idx >= actual_num_kv_heads:
|
|
840
879
|
break
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
880
|
+
cur_kv_head_bq = load_bq(bq_sem_idx,
|
|
881
|
+
cur_kv_head_idx,
|
|
882
|
+
actual_bq_sz=actual_bq_sz)
|
|
883
|
+
cur_kv_head__bkv = bkv_lst[i]
|
|
884
|
+
# FlashAttention is divided into `flash_attention_step1_qk_softmax`
|
|
885
|
+
# and `flash_attention_step2_pv` to pipeline the computation.
|
|
886
|
+
# `step2_pv` for the previous KV head, which depends on the softmax
|
|
887
|
+
# output, is overlapped with `step1_qk_softmax` for the current KV
|
|
888
|
+
# head, reducing overall wait times.
|
|
889
|
+
cur_kv_head_p, cur_kv_head_exp_m_diff = (
|
|
890
|
+
flash_attention_step1_qk_softmax(
|
|
891
|
+
cur_kv_head_bq,
|
|
892
|
+
cur_kv_head__bkv,
|
|
893
|
+
bq_idx=bq_idx,
|
|
894
|
+
bkv_idx=bkv_idx,
|
|
895
|
+
kv_head_idx=cur_kv_head_idx,
|
|
896
|
+
))
|
|
897
|
+
if prev_bq_shape_0 is not None:
|
|
898
|
+
flash_attention_step2_pv(
|
|
899
|
+
prev_bq_shape_0,
|
|
900
|
+
prev_kv_head_bkv,
|
|
901
|
+
prev_kv_head_p,
|
|
902
|
+
prev_kv_head_exp_m_diff,
|
|
903
|
+
bkv_idx=bkv_idx,
|
|
904
|
+
kv_head_idx=prev_kv_head_idx,
|
|
905
|
+
)
|
|
906
|
+
prev_bq_shape_0 = cur_kv_head_bq.shape[0]
|
|
907
|
+
prev_kv_head_bkv = cur_kv_head__bkv
|
|
908
|
+
prev_kv_head_p = cur_kv_head_p
|
|
909
|
+
prev_kv_head_exp_m_diff = cur_kv_head_exp_m_diff
|
|
910
|
+
prev_kv_head_idx = cur_kv_head_idx
|
|
911
|
+
|
|
912
|
+
# Execute pv of last attention head.
|
|
913
|
+
assert prev_bq_shape_0 is not None
|
|
914
|
+
flash_attention_step2_pv(
|
|
915
|
+
prev_bq_shape_0,
|
|
916
|
+
prev_kv_head_bkv,
|
|
917
|
+
prev_kv_head_p,
|
|
918
|
+
prev_kv_head_exp_m_diff,
|
|
919
|
+
bkv_idx=bkv_idx,
|
|
920
|
+
kv_head_idx=prev_kv_head_idx,
|
|
921
|
+
)
|
|
852
922
|
|
|
853
923
|
lax.fori_loop(bkv_idx_start,
|
|
854
924
|
num_bkv,
|
|
@@ -884,7 +954,17 @@ def _ragged_paged_attention_kernel(
|
|
|
884
954
|
@pl.when(seq_idx == 0)
|
|
885
955
|
def prologue():
|
|
886
956
|
start_fetch_bq(0, 0, 0)
|
|
957
|
+
|
|
958
|
+
# Initialize bkv_x2_ref to zeros to avoid NaN issues from accessing
|
|
959
|
+
# uninitialized memory. Bitcast into int32 to avoid tiling issues.
|
|
960
|
+
bkv_x2_int32_ref = bkv_x2_ref.bitcast(jnp.int32).reshape(
|
|
961
|
+
(2, -1, 8, 128))
|
|
962
|
+
zeros = jnp.zeros(bkv_x2_int32_ref.shape[1:], jnp.int32)
|
|
963
|
+
|
|
964
|
+
# To pipeline VST and DMA, we divide the initialization into two steps.
|
|
965
|
+
bkv_x2_int32_ref[0] = zeros
|
|
887
966
|
start_fetch_bkv(0, bkv_idx_start, 0)
|
|
967
|
+
bkv_x2_int32_ref[1] = zeros
|
|
888
968
|
|
|
889
969
|
@pl.when(seq_idx < decode_end)
|
|
890
970
|
def process_decode():
|
|
@@ -1244,6 +1324,10 @@ def static_validate_inputs(
|
|
|
1244
1324
|
del debug_mode
|
|
1245
1325
|
|
|
1246
1326
|
|
|
1327
|
+
def get_kernel_scope_name(bq_size, bkv_p, page_size):
|
|
1328
|
+
return f"RPA-HD_64-bq_{bq_size}-bkvp_{bkv_p}-p_{page_size}-"
|
|
1329
|
+
|
|
1330
|
+
|
|
1247
1331
|
@functools.partial(
|
|
1248
1332
|
jax.jit,
|
|
1249
1333
|
static_argnames=(
|
|
@@ -1292,42 +1376,40 @@ def ragged_paged_attention_hd64(
|
|
|
1292
1376
|
# Debug params.
|
|
1293
1377
|
debug_mode: bool = False,
|
|
1294
1378
|
):
|
|
1295
|
-
"""A
|
|
1296
|
-
|
|
1297
|
-
|
|
1298
|
-
|
|
1299
|
-
|
|
1300
|
-
|
|
1301
|
-
|
|
1302
|
-
|
|
1303
|
-
|
|
1304
|
-
|
|
1305
|
-
|
|
1306
|
-
|
|
1307
|
-
|
|
1308
|
-
|
|
1309
|
-
|
|
1310
|
-
|
|
1311
|
-
|
|
1312
|
-
|
|
1313
|
-
|
|
1314
|
-
|
|
1315
|
-
|
|
1316
|
-
|
|
1317
|
-
|
|
1318
|
-
|
|
1319
|
-
|
|
1320
|
-
|
|
1321
|
-
|
|
1322
|
-
|
|
1323
|
-
|
|
1324
|
-
|
|
1325
|
-
|
|
1326
|
-
|
|
1327
|
-
|
|
1328
|
-
|
|
1329
|
-
The output of the attention.
|
|
1330
|
-
"""
|
|
1379
|
+
"""A variant of ragged paged attention for head_dim=64.
|
|
1380
|
+
|
|
1381
|
+
Args:
|
|
1382
|
+
queries: concatenated all sequences' queries.
|
|
1383
|
+
keys: concatenated all sequences' keys (quantized).
|
|
1384
|
+
values: concatenated all sequences' values (quantized).
|
|
1385
|
+
kv_cache: paged KV cache with TPU-friendly shape.
|
|
1386
|
+
kv_lens: padded kv lengths. Only the first num_seqs values are valid.
|
|
1387
|
+
page_indices: flattened page indices look-up table by (seq_id, page_id).
|
|
1388
|
+
cu_q_lens: the cumulative sum of the effective query lengths. Similar to
|
|
1389
|
+
kv_lens, only the first num_seqs+1 values are valid.
|
|
1390
|
+
distribution: (i, j, k) represents that sequences[0:i] are decode-only,
|
|
1391
|
+
sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
|
|
1392
|
+
k is also the total number of sequences.
|
|
1393
|
+
attention_sink: optional attention sink for each q head.
|
|
1394
|
+
sm_scale: the softmax scale which will be applied to the Q@K^T.
|
|
1395
|
+
sliding_window: the sliding window size for the attention.
|
|
1396
|
+
soft_cap: the logit soft cap for the attention.
|
|
1397
|
+
mask_value: mask value for causal mask.
|
|
1398
|
+
q_scale: the scale for the query.
|
|
1399
|
+
k_scale: the scale for the key cache.
|
|
1400
|
+
v_scale: the scale for the value cache.
|
|
1401
|
+
chunk_prefill_size: the chunk prefill size for the attention.
|
|
1402
|
+
num_kv_pages_per_block: number of kv pages to be processed in one flash
|
|
1403
|
+
attention block in the pallas kernel.
|
|
1404
|
+
num_queries_per_block: number of queries to be processed in one flash
|
|
1405
|
+
attention block in the pallas kernel.
|
|
1406
|
+
vmem_limit_bytes: the vmem limit for the pallas kernel.
|
|
1407
|
+
debug_mode: if true, RPA does not issue any DMAs or run flash attention but
|
|
1408
|
+
prints debug info. Need to compile with `--xla_tpu_enable_log_recorder`.
|
|
1409
|
+
|
|
1410
|
+
Returns:
|
|
1411
|
+
The output of the attention.
|
|
1412
|
+
"""
|
|
1331
1413
|
q, k, v = queries, keys, values
|
|
1332
1414
|
static_validate_inputs(
|
|
1333
1415
|
q,
|
|
@@ -1384,6 +1466,7 @@ def ragged_paged_attention_hd64(
|
|
|
1384
1466
|
page_size,
|
|
1385
1467
|
max_num_tokens,
|
|
1386
1468
|
pages_per_seq,
|
|
1469
|
+
sliding_window,
|
|
1387
1470
|
)
|
|
1388
1471
|
bkv_sz = bkv_p * page_size
|
|
1389
1472
|
if vmem_limit_bytes is None:
|
|
@@ -1397,7 +1480,7 @@ def ragged_paged_attention_hd64(
|
|
|
1397
1480
|
pl.BlockSpec(memory_space=pltpu.HBM),
|
|
1398
1481
|
pl.BlockSpec(memory_space=pltpu.HBM),
|
|
1399
1482
|
None if attention_sink is None else pl.BlockSpec(
|
|
1400
|
-
memory_space=pltpu.VMEM)
|
|
1483
|
+
memory_space=pltpu.VMEM),
|
|
1401
1484
|
]
|
|
1402
1485
|
|
|
1403
1486
|
out_specs = [
|
|
@@ -1454,47 +1537,45 @@ def ragged_paged_attention_hd64(
|
|
|
1454
1537
|
jnp.full((6, ), -1, jnp.int32),
|
|
1455
1538
|
)
|
|
1456
1539
|
|
|
1457
|
-
scope_name =
|
|
1458
|
-
kernel =
|
|
1459
|
-
|
|
1460
|
-
|
|
1461
|
-
|
|
1462
|
-
|
|
1463
|
-
|
|
1464
|
-
|
|
1465
|
-
|
|
1466
|
-
|
|
1467
|
-
|
|
1468
|
-
|
|
1469
|
-
|
|
1470
|
-
|
|
1471
|
-
|
|
1472
|
-
|
|
1473
|
-
|
|
1474
|
-
|
|
1475
|
-
|
|
1476
|
-
|
|
1477
|
-
|
|
1478
|
-
|
|
1479
|
-
|
|
1480
|
-
|
|
1481
|
-
|
|
1482
|
-
|
|
1483
|
-
|
|
1484
|
-
|
|
1485
|
-
|
|
1486
|
-
|
|
1487
|
-
|
|
1488
|
-
|
|
1489
|
-
|
|
1490
|
-
|
|
1491
|
-
|
|
1492
|
-
|
|
1493
|
-
|
|
1494
|
-
|
|
1495
|
-
|
|
1496
|
-
name=scope_name,
|
|
1497
|
-
))
|
|
1540
|
+
scope_name = get_kernel_scope_name(bq_sz, bkv_p, page_size)
|
|
1541
|
+
kernel = pl.pallas_call(
|
|
1542
|
+
functools.partial(
|
|
1543
|
+
_ragged_paged_attention_kernel,
|
|
1544
|
+
sm_scale=sm_scale,
|
|
1545
|
+
sliding_window=sliding_window,
|
|
1546
|
+
soft_cap=soft_cap,
|
|
1547
|
+
mask_value=mask_value,
|
|
1548
|
+
q_scale=q_scale,
|
|
1549
|
+
k_scale=k_scale,
|
|
1550
|
+
v_scale=v_scale,
|
|
1551
|
+
chunk_prefill_size=chunk_prefill_size,
|
|
1552
|
+
bq_sz=bq_sz,
|
|
1553
|
+
bkv_p=bkv_p,
|
|
1554
|
+
debug_mode=debug_mode,
|
|
1555
|
+
),
|
|
1556
|
+
grid_spec=pltpu.PrefetchScalarGridSpec(
|
|
1557
|
+
num_scalar_prefetch=len(scalar_prefetches),
|
|
1558
|
+
in_specs=in_specs,
|
|
1559
|
+
out_specs=out_specs,
|
|
1560
|
+
grid=grid,
|
|
1561
|
+
scratch_shapes=scratch_shapes,
|
|
1562
|
+
),
|
|
1563
|
+
compiler_params=pltpu.CompilerParams(
|
|
1564
|
+
# TODO(jevinjiang): since each sequence depends on the previous
|
|
1565
|
+
# one, we need some extra work to support Megacore mode.
|
|
1566
|
+
dimension_semantics=("arbitrary", ),
|
|
1567
|
+
vmem_limit_bytes=vmem_limit_bytes,
|
|
1568
|
+
),
|
|
1569
|
+
out_shape=[
|
|
1570
|
+
jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
|
|
1571
|
+
jax.ShapeDtypeStruct(shape=kv_cache.shape, dtype=kv_cache.dtype),
|
|
1572
|
+
],
|
|
1573
|
+
input_output_aliases={
|
|
1574
|
+
7: 0,
|
|
1575
|
+
9: 1
|
|
1576
|
+
},
|
|
1577
|
+
name=scope_name,
|
|
1578
|
+
)
|
|
1498
1579
|
|
|
1499
1580
|
output, updated_kv_cache = kernel(*scalar_prefetches, q, kv, kv_cache,
|
|
1500
1581
|
attention_sink)
|