tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tests/kernels/ragged_paged_attention_kernel_v3_test.py
ADDED
@@ -0,0 +1,520 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
import jax.numpy as jnp
import numpy as np
from absl.testing import absltest, parameterized
from jax._src import dtypes
from jax._src import test_util as jtu

from tpu_inference.kernels.ragged_paged_attention.v3.kernel import (
    ragged_paged_attention, ref_ragged_paged_attention)
from tpu_inference.kernels.ragged_paged_attention.v3.util import (
    align_to, cdiv, get_dtype_packing)

jax.config.parse_flags_with_absl()


@jtu.with_config(jax_numpy_dtype_promotion="standard")
class RaggedPagedAttentionKernelTest(jtu.JaxTestCase):

    def _test_ragged_paged_attention(
        self,
        seq_lens,  # List[(q_len, kv_len)]
        num_heads,  # [num_q_heads, num_kv_heads]
        head_dim,
        page_size,
        q_dtype,
        kv_dtype,
        num_pages,
        *,
        num_kv_pages_per_block=8,
        num_queries_per_block=64,
        vmem_limit_bytes=100 * 1024 * 1024,
        max_num_batched_tokens=512,
        max_num_seq=8,
        sliding_window: int | None = None,
        soft_cap: float | None = None,
        q_scale: float | None = None,
        k_scale: float | None = None,
        v_scale: float | None = None,
    ):
        rng = np.random.default_rng(1234)

        def gen_random(shape, dtype):
            return jnp.array(rng.random(size=shape,
                                        dtype=np.float32)).astype(dtype)

        if not jtu.is_device_tpu_at_least(version=4):
            self.skipTest("Expect TPUv4+")
        cu_q_lens = [0]
        kv_lens = []
        for q_len, kv_len in seq_lens:
            assert q_len <= kv_len
            cu_q_lens.append(cu_q_lens[-1] + q_len)
            kv_lens.append(kv_len)

        max_num_batched_tokens = max(align_to(cu_q_lens[-1], 128),
                                     max_num_batched_tokens)
        max_num_seq = max(align_to(len(seq_lens), 8), max_num_seq)
        max_kv_len = max(kv_lens)
        pages_per_seq = cdiv(max_kv_len, page_size)
        num_q_heads, num_kv_heads = num_heads

        q = gen_random((max_num_batched_tokens, num_q_heads, head_dim),
                       q_dtype)
        k = gen_random((max_num_batched_tokens, num_kv_heads, head_dim),
                       kv_dtype)
        v = gen_random((max_num_batched_tokens, num_kv_heads, head_dim),
                       kv_dtype)
        page_cnt = 0
        page_indices_list = []
        kv_pages_list = []
        kv_packing = get_dtype_packing(kv_dtype)
        padded_head_dim = align_to(head_dim, 128)
        num_kv_heads_x2 = align_to(num_kv_heads * 2, kv_packing)
        for kv_len in kv_lens:
            kv = gen_random((
                kv_len,
                num_kv_heads_x2 // kv_packing,
                kv_packing,
                padded_head_dim,
            ), kv_dtype)
            kv = jnp.pad(
                kv,
                (
                    (
                        0,
                        cdiv(kv_len, page_size) * page_size - kv_len,
                    ),
                    (0, 0),
                    (0, 0),
                    (0, 0),
                ),
                constant_values=jnp.nan,
            ).reshape(
                -1,
                page_size,
                num_kv_heads_x2 // kv_packing,
                kv_packing,
                padded_head_dim,
            )
            indices = page_cnt + jnp.arange(kv.shape[0], dtype=jnp.int32)
            indices = jnp.pad(
                indices,
                ((0, pages_per_seq - indices.shape[0]), ),
                constant_values=jnp.nan,
            )
            page_indices_list.append(indices)
            page_cnt += kv.shape[0]
            kv_pages_list.append(kv)

        kv_cache = jnp.concatenate(kv_pages_list, axis=0)
        kv_cache = jnp.pad(
            kv_cache,
            ((0, num_pages - kv_cache.shape[0]), (0, 0), (0, 0), (0, 0),
             (0, 0)),
            constant_values=jnp.nan,
        )
        page_indices = jnp.stack(page_indices_list, axis=0)
        page_indices = jnp.pad(
            page_indices,
            ((0, max_num_seq - page_indices.shape[0]), (0, 0)),
            constant_values=jnp.nan,
        )
        page_indices = page_indices.reshape(-1)

        cu_q_lens = jnp.array(cu_q_lens, dtype=jnp.int32)
        cu_q_lens = jnp.pad(cu_q_lens,
                            (0, max_num_seq + 1 - cu_q_lens.shape[0]))
        kv_lens = jnp.array(kv_lens, dtype=jnp.int32)
        kv_lens = jnp.pad(kv_lens, (0, max_num_seq - kv_lens.shape[0]))
        distribution = jnp.array([0, 0, len(seq_lens)], dtype=jnp.int32)

        args = (
            q,
            k,
            v,
            kv_cache,
            kv_lens,
            page_indices,
            cu_q_lens,
            distribution,
        )

        kwargs = {
            "sliding_window": sliding_window,
            "soft_cap": soft_cap,
            "q_scale": q_scale,
            "k_scale": k_scale,
            "v_scale": v_scale,
        }

        expected, expected_kv_cache = ref_ragged_paged_attention(
            *args,
            **kwargs,
        )

        output, updated_kv_cache = ragged_paged_attention(
            *args,
            **kwargs,
            num_kv_pages_per_block=num_kv_pages_per_block,
            num_queries_per_block=num_queries_per_block,
            vmem_limit_bytes=vmem_limit_bytes,
        )
        output = output[:cu_q_lens[distribution[-1]]]

        dtype_bits = (dtypes.bit_width(jnp.dtype(kv_dtype)) if hasattr(
            dtypes, "bit_width") else dtypes.itemsize_bits(
                jnp.dtype(kv_dtype)))
        tols = {
            32: 0.15,
            16: 0.2,
            8: 0.2,
            4: 0.2,
        }
        tol = tols[dtype_bits]
        self.assertAllClose(output, expected, atol=tol, rtol=tol)
        mask = ~jnp.isnan(expected_kv_cache)
        self.assertArraysEqual(updated_kv_cache[mask], expected_kv_cache[mask])
        self.assertEqual(output.shape[-1], head_dim)

    @parameterized.product(dtype=[jnp.float32, jnp.bfloat16], )
    def test_ragged_paged_attention_basic(self, dtype):
        seq_lens = [(192, 328), (128, 180), (64, 255)]
        num_heads = (32, 8)
        head_dim = 128
        page_size = 16
        num_pages = 1000

        self._test_ragged_paged_attention(
            seq_lens,
            num_heads,
            head_dim,
            page_size,
            dtype,
            dtype,
            num_pages,
        )

    # TODO: support integer (int8, int4) and fp4 kv cache
    @parameterized.product(
        q_dtype=[jnp.bfloat16],
        kv_dtype=[jnp.float8_e5m2, jnp.float8_e4m3fn],
        kv_scales=[(0.5, 0.5), (1.0, 1.0)],
    )
    def test_ragged_paged_attention_quantized_kv_cache(self, q_dtype, kv_dtype,
                                                       kv_scales):
        if not jtu.is_device_tpu_at_least(version=5):
            self.skipTest("Expect TPUv5+")
        seq_lens = [(192, 328), (128, 180), (64, 255)]
        num_heads = (32, 8)
        head_dim = 128
        page_size = 16
        num_pages = 1000
        k_scale, v_scale = kv_scales

        self._test_ragged_paged_attention(
            seq_lens,
            num_heads,
            head_dim,
            page_size,
            q_dtype,
            kv_dtype,
            num_pages,
            k_scale=k_scale,
            v_scale=v_scale,
        )

    @parameterized.product(
        q_dtype=[jnp.bfloat16],
        kv_dtype=[jnp.float8_e5m2, jnp.float8_e4m3fn],
        q_scale=[0.5, 1.0],
        kv_scales=[(0.5, 0.5), (1.0, 1.0)],
    )
    def test_ragged_paged_attention_quantized_attention(
            self, q_dtype, kv_dtype, q_scale, kv_scales):
        if not jtu.is_device_tpu_at_least(version=5):
            self.skipTest("Expect TPUv5+")
        seq_lens = [(192, 328), (128, 180), (64, 255)]
        num_heads = (32, 8)
        head_dim = 128
        page_size = 16
        num_pages = 1000
        k_scale, v_scale = kv_scales

        self._test_ragged_paged_attention(
            seq_lens,
            num_heads,
            head_dim,
            page_size,
            q_dtype,
            kv_dtype,
            num_pages,
            q_scale=q_scale,
            k_scale=k_scale,
            v_scale=v_scale,
        )

    @parameterized.product(dtype=[jnp.float32, jnp.bfloat16], )
    def test_ragged_paged_attention_decode_only(self, dtype):
        seq_lens = [
            (1, 18),
            (1, 129),
            (1, 597),
            (1, 122),
            (1, 64),
            (1, 322),
            (1, 463),
            (1, 181),
            (1, 1107),
            (1, 123),
            (1, 31),
            (1, 18),
            (1, 1229),
            (1, 229),
            (1, 87),
            (1, 1328),
        ]
        num_heads = (32, 8)
        head_dim = 128
        page_size = 16
        num_pages = 1000

        self._test_ragged_paged_attention(
            seq_lens,
            num_heads,
            head_dim,
            page_size,
            dtype,
            dtype,
            num_pages,
        )

    @parameterized.product(dtype=[jnp.float32, jnp.bfloat16], )
    def test_ragged_paged_attention_prefill_only(self, dtype):
        seq_lens = [
            (5, 18),
            (15, 129),
            (120, 597),
            (100, 122),
            (21, 64),
            (32, 322),
            (251, 463),
            (40, 181),
            (64, 1107),
            (99, 123),
            (10, 31),
            (5, 18),
            (3, 1229),
            (120, 229),
            (9, 87),
            (2, 1328),
        ]
        num_heads = (32, 8)
        head_dim = 128
        page_size = 16
        num_pages = 1000

        self._test_ragged_paged_attention(
            seq_lens,
            num_heads,
            head_dim,
            page_size,
            dtype,
            dtype,
            num_pages,
        )

    @parameterized.product(dtype=[jnp.float32, jnp.bfloat16], )
    def test_ragged_paged_attention_mixed(self, dtype):
        seq_lens = [
            (5, 18),
            (1, 129),
            (120, 597),
            (1, 122),
            (1, 64),
            (32, 322),
            (251, 463),
            (1, 181),
            (1, 1107),
            (99, 123),
            (1, 31),
            (5, 18),
            (3, 1229),
            (117, 229),
            (1, 87),
            (1, 1328),
        ]
        num_heads = (32, 8)
        head_dim = 128
        page_size = 16
        num_pages = 1000

        self._test_ragged_paged_attention(
            seq_lens,
            num_heads,
            head_dim,
            page_size,
            dtype,
            dtype,
            num_pages,
        )

    @parameterized.product(
        num_seqs=[1, 17],
        num_heads=[(32, 8), (12, 2), (5, 1), (3, 3)],
        head_dim=[80, 240],
        dtype=[jnp.float32, jnp.bfloat16],
        # num_kv_pages_per_block=[8, 16],
        # num_queries_per_block=[16, 32],
    )
    def test_ragged_paged_attention_complex(
        self,
        num_seqs,
        num_heads,
        head_dim,
        dtype,
        # num_kv_pages_per_block,
        # num_queries_per_block,
    ):
        rng = np.random.default_rng(1234)
        q_lens = rng.integers(1, 100, num_seqs)
        kv_lens = q_lens + rng.integers(0, 50, num_seqs)
        seq_lens = list(zip(q_lens.tolist(), kv_lens.tolist()))
        page_size = 16
        num_pages = 1000

        self._test_ragged_paged_attention(
            seq_lens,
            num_heads,
            head_dim,
            page_size,
            dtype,
            dtype,
            num_pages,
            # num_kv_pages_per_block=num_kv_pages_per_block,
            # num_queries_per_block=num_queries_per_block,
        )

    @parameterized.product(sliding_window=[None, 5, 128], )
    def test_ragged_paged_attention_sliding_window(
        self,
        sliding_window: int | None,
    ):
        num_seqs = 5
        num_heads = (4, 4)
        dtype = jnp.float32
        rng = np.random.default_rng(1234)
        q_lens = rng.integers(1, 100, num_seqs)
        kv_lens = q_lens + rng.integers(0, 50, num_seqs)
        seq_lens = list(zip(q_lens.tolist(), kv_lens.tolist()))
        head_dim = 128
        page_size = 16
        num_pages = 1000

        self._test_ragged_paged_attention(
            seq_lens,
            num_heads,
            head_dim,
            page_size,
            dtype,
            dtype,
            num_pages,
            sliding_window=sliding_window,
        )

    @parameterized.product(soft_cap=[None, 50.0], )
    def test_ragged_paged_attention_logit_soft_capping(
        self,
        soft_cap: float | None,
    ):
        num_heads = (16, 2)
        num_seqs = 2
        dtype = jnp.float32
        rng = np.random.default_rng(1234)
        q_lens = rng.integers(1, 100, num_seqs)
        kv_lens = q_lens + rng.integers(0, 50, num_seqs)
        seq_lens = list(zip(q_lens.tolist(), kv_lens.tolist()))
        head_dim = 128
        page_size = 16
        num_pages = 1000

        self._test_ragged_paged_attention(
            seq_lens,
            num_heads,
            head_dim,
            page_size,
            dtype,
            dtype,
            num_pages,
            soft_cap=soft_cap,
        )

    def test_ragged_paged_attention_sliding_window_should_be_positive(self):
        dtype = jnp.float32
        seq_lens = [(192, 328), (128, 180), (64, 255)]
        num_heads = (32, 8)
        head_dim = 128
        page_size = 16
        num_pages = 1000

        with self.assertRaisesRegex(ValueError, "must be positive"):
            self._test_ragged_paged_attention(
                seq_lens,
                num_heads,
                head_dim,
                page_size,
                dtype,
                dtype,
                num_pages,
                sliding_window=0,
            )

        with self.assertRaisesRegex(ValueError, "must be positive"):
            self._test_ragged_paged_attention(
                seq_lens,
                num_heads,
                head_dim,
                page_size,
                dtype,
                dtype,
                num_pages,
                sliding_window=-1,
            )

    def test_ragged_paged_attention_soft_cap_cannot_be_zero(self):
        dtype = jnp.float32
        seq_lens = [(192, 328), (128, 180), (64, 255)]
        num_heads = (32, 8)
        head_dim = 128
        page_size = 16
        num_pages = 1000

        with self.assertRaisesRegex(ValueError, "must not be 0.0"):
            self._test_ragged_paged_attention(
                seq_lens,
                num_heads,
                head_dim,
                page_size,
                dtype,
                dtype,
                num_pages,
                soft_cap=0.0,
            )


if __name__ == "__main__":
    absltest.main(testLoader=jtu.JaxTestLoader())
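Read end to end, the test above pins down the v3 kernel's calling convention: eight positional arrays (q, k, v, packed kv_cache, kv_lens, flattened page_indices, cu_q_lens, distribution) plus keyword tuning knobs, returning the attention output and the updated cache. The snippet below is a minimal sketch of a direct call distilled from that test; the concrete sizes are illustrative, the cache contents are dummy zeros, and like the test it assumes a TPU v4+ host with this package installed.

import jax.numpy as jnp
import numpy as np

from tpu_inference.kernels.ragged_paged_attention.v3.kernel import (
    ragged_paged_attention)
from tpu_inference.kernels.ragged_paged_attention.v3.util import (
    align_to, cdiv, get_dtype_packing)

num_q_heads, num_kv_heads, head_dim = 8, 2, 128
page_size, num_pages, max_num_seq = 16, 64, 8
max_num_batched_tokens = 128  # >= total query tokens, aligned to 128 as in the test
dtype = jnp.bfloat16

rng = np.random.default_rng(0)
q = jnp.asarray(rng.random((max_num_batched_tokens, num_q_heads, head_dim), np.float32), dtype)
k = jnp.asarray(rng.random((max_num_batched_tokens, num_kv_heads, head_dim), np.float32), dtype)
v = jnp.asarray(rng.random((max_num_batched_tokens, num_kv_heads, head_dim), np.float32), dtype)

# Packed KV-cache layout mirrored from the test:
# (num_pages, page_size, align_to(2 * num_kv_heads, packing) // packing, packing, padded head_dim)
packing = get_dtype_packing(dtype)  # packing factor the test derives for the KV dtype
kv_cache = jnp.zeros((num_pages, page_size,
                      align_to(2 * num_kv_heads, packing) // packing, packing,
                      align_to(head_dim, 128)), dtype)

# One request: 100 new query tokens attending to 120 total KV tokens.
q_len, kv_len = 100, 120
pages_per_seq = cdiv(kv_len, page_size)
kv_lens = jnp.zeros((max_num_seq, ), jnp.int32).at[0].set(kv_len)
cu_q_lens = jnp.zeros((max_num_seq + 1, ), jnp.int32).at[1].set(q_len)
page_indices = jnp.zeros((max_num_seq * pages_per_seq, ), jnp.int32).at[:pages_per_seq].set(
    jnp.arange(pages_per_seq, dtype=jnp.int32))
distribution = jnp.array([0, 0, 1], jnp.int32)  # same construction as the test: [0, 0, num_seqs]

output, updated_kv_cache = ragged_paged_attention(
    q, k, v, kv_cache, kv_lens, page_indices, cu_q_lens, distribution,
    sliding_window=None, soft_cap=None,
    num_kv_pages_per_block=8, num_queries_per_block=64,
    vmem_limit_bytes=100 * 1024 * 1024)
output = output[:cu_q_lens[distribution[-1]]]  # drop padding rows, as the test does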
tests/layers/__init__.py
ADDED
@@ -0,0 +1,13 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
tests/layers/common/__init__.py
ADDED
@@ -0,0 +1,13 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
tests/layers/common/test_attention_interface.py
ADDED
@@ -0,0 +1,156 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import MagicMock

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax.sharding import Mesh

from tpu_inference.layers.common.attention_interface import attention
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
from tpu_inference.runner.kv_cache import get_kv_cache_shape_with_mesh

# ---- Test Configuration & Constants ----

# Total number of tokens across all sequences in the batch
TOTAL_TOKENS = 10
# Number of sequences in the batch
NUM_SEQS = 2
# Padded maximum number of sequences
MAX_NUM_SEQS = 4
# Number of attention heads (Query)
NUM_HEADS = 8
# Number of attention heads (Key/Value) - for Grouped-Query Attention
NUM_KV_HEADS = 4
# Total number of blocks in the KV cache
NUM_BLOCKS = 32
# Number of tokens per block
BLOCK_SIZE = 16
# Maximum number of blocks a single sequence can occupy
MAX_BLOCKS_PER_SEQ = 8


@pytest.fixture
def mesh():
    """Provides a mock 1D JAX mesh for testing."""
    # Create a mesh with available devices, useful for running on CPU/GPU/TPU
    # For this test, it will likely be a single CPU device.
    devices = np.array(jax.local_devices()[:1])
    if not devices.any():
        # Add a mock device if no devices are present (e.g., in a CI environment)
        devices = np.array([jax.devices("cpu")[0]])
    return Mesh(devices.reshape((-1, 1, 1)), ("data", "attn_dp", "model"))


# ---- Test for `attention` ----


def _test_attention(monkeypatch, mesh, head_dim, use_sinks=False):
    """
    Tests the main `attention` function.

    Verifies that:
    1. It calls the `sharded_ragged_paged_attention` kernel with correct metadata.
    2. The final outputs (kv_cache and attention output) have the correct shapes.
    """
    # 1. Arrange

    # Create input tensors
    q_dtype = jnp.float32
    kv_dtype = jnp.float32
    q = jnp.ones((TOTAL_TOKENS, NUM_HEADS, head_dim), dtype=q_dtype)
    k = jnp.ones((TOTAL_TOKENS, NUM_KV_HEADS, head_dim), dtype=kv_dtype)
    v = jnp.ones((TOTAL_TOKENS, NUM_KV_HEADS, head_dim), dtype=kv_dtype)
    sinks = jnp.ones((NUM_HEADS, ), dtype=jnp.float32) if use_sinks else None

    kv_cache_shape = get_kv_cache_shape_with_mesh(
        mesh,
        NUM_BLOCKS,
        BLOCK_SIZE,
        NUM_KV_HEADS,
        head_dim,
        kv_dtype,
    )
    kv_cache = jnp.zeros(kv_cache_shape, dtype=kv_dtype)

    # Mock ragged_paged_attention to return a tensor of the correct shape
    mock_paged_attn_kernel = MagicMock(return_value=(jnp.ones(
        (TOTAL_TOKENS, NUM_HEADS, head_dim)), kv_cache), )

    if head_dim == 64:
        monkeypatch.setattr(
            "tpu_inference.layers.common.attention_interface.ragged_paged_attention_hd64",
            mock_paged_attn_kernel,
        )
    else:
        monkeypatch.setattr(
            "tpu_inference.layers.common.attention_interface.ragged_paged_attention",
            mock_paged_attn_kernel,
        )

    # Create AttentionMetadata
    attention_metadata = AttentionMetadata(
        input_positions=jnp.arange(TOTAL_TOKENS, dtype=jnp.int32),
        block_tables=jnp.zeros((MAX_NUM_SEQS * MAX_BLOCKS_PER_SEQ, ),
                               dtype=jnp.int32),
        seq_lens=jnp.array([5, 5, 0, 0], dtype=jnp.int32),
        query_start_loc=jnp.array([0, 5, 10, 10, 10], dtype=jnp.int32),
        request_distribution=jnp.array([0, 0, NUM_SEQS], dtype=jnp.int32),
    )

    # 2. Act
    final_kv_cache, output = attention(
        kv_cache=kv_cache,
        q=q,
        k=k,
        v=v,
        attention_metadata=attention_metadata,
        mesh=mesh,
        head_dim_original=head_dim,
        sinks=sinks,
    )

    # 3. Assert
    # Check that both mocked kernels were called
    mock_paged_attn_kernel.assert_called_once()

    # Check output shapes
    assert final_kv_cache.shape == kv_cache.shape
    assert output.shape == q.shape

    # Check that the output is the one from our mock
    assert jnp.all(output == 1.0)


def test_attention(monkeypatch, mesh):
    _test_attention(monkeypatch, mesh, 128)


def test_attention_hd64(monkeypatch, mesh):
    _test_attention(monkeypatch, mesh, 64)


def test_attention_sink(monkeypatch, mesh):
    _test_attention(monkeypatch, mesh, 64, True)


def test_attention_sink_no_64_raises_error(monkeypatch, mesh):
    with pytest.raises(
            NotImplementedError,
            match="Attention sink support is only available when head_dim==64"
    ):
        _test_attention(monkeypatch, mesh, 128, True)
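The fixture and helper above also double as a usage recipe for the high-level `attention` interface. Below is a condensed, standalone sketch of the same call without pytest; every name, shape, and constant is taken from the test, and the underlying kernel is stubbed out exactly as the test does so the sketch also runs on a CPU-only host. On real TPU hardware the patch would be dropped and genuine block tables supplied.

from unittest.mock import MagicMock, patch

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh

from tpu_inference.layers.common.attention_interface import attention
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
from tpu_inference.runner.kv_cache import get_kv_cache_shape_with_mesh

TOTAL_TOKENS, NUM_SEQS, MAX_NUM_SEQS = 10, 2, 4
NUM_HEADS, NUM_KV_HEADS, HEAD_DIM = 8, 4, 128
NUM_BLOCKS, BLOCK_SIZE, MAX_BLOCKS_PER_SEQ = 32, 16, 8

# Single-device mesh with the same ("data", "attn_dp", "model") axes as the fixture.
devices = np.array(jax.local_devices()[:1]).reshape((-1, 1, 1))
mesh = Mesh(devices, ("data", "attn_dp", "model"))

q = jnp.ones((TOTAL_TOKENS, NUM_HEADS, HEAD_DIM), jnp.float32)
k = jnp.ones((TOTAL_TOKENS, NUM_KV_HEADS, HEAD_DIM), jnp.float32)
v = jnp.ones((TOTAL_TOKENS, NUM_KV_HEADS, HEAD_DIM), jnp.float32)
kv_cache = jnp.zeros(
    get_kv_cache_shape_with_mesh(mesh, NUM_BLOCKS, BLOCK_SIZE, NUM_KV_HEADS,
                                 HEAD_DIM, jnp.float32), jnp.float32)

# Two sequences of 5 tokens each, padded out to MAX_NUM_SEQS entries.
metadata = AttentionMetadata(
    input_positions=jnp.arange(TOTAL_TOKENS, dtype=jnp.int32),
    block_tables=jnp.zeros((MAX_NUM_SEQS * MAX_BLOCKS_PER_SEQ, ), jnp.int32),
    seq_lens=jnp.array([5, 5, 0, 0], jnp.int32),
    query_start_loc=jnp.array([0, 5, 10, 10, 10], jnp.int32),
    request_distribution=jnp.array([0, 0, NUM_SEQS], jnp.int32),
)

# Stub the kernel just like the test so this also works off-TPU.
fake_kernel = MagicMock(
    return_value=(jnp.ones((TOTAL_TOKENS, NUM_HEADS, HEAD_DIM)), kv_cache))
with patch(
        "tpu_inference.layers.common.attention_interface.ragged_paged_attention",
        fake_kernel):
    new_kv_cache, out = attention(kv_cache=kv_cache, q=q, k=k, v=v,
                                  attention_metadata=metadata, mesh=mesh,
                                  head_dim_original=HEAD_DIM, sinks=None)

assert new_kv_cache.shape == kv_cache.shape and out.shape == q.shape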