tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.2rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +89 -26
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +46 -17
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +44 -17
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/METADATA +7 -9
- tpu_inference-0.13.2rc3.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/top_level.txt +0 -0
tests/runner/test_kv_cache.py (new file)

@@ -0,0 +1,220 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock, patch
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+import torch
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from vllm.v1.kv_cache_interface import FullAttentionSpec, MLAAttentionSpec
+
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.runner.kv_cache import (create_kv_caches,
+                                           get_attention_page_size_bytes,
+                                           get_kv_cache_shape_with_mesh)
+from tpu_inference.utils import get_dtype_packing
+
+
+@pytest.fixture
+def mesh():
+    devices = np.array(jax.local_devices()[:1])
+    devices = devices.reshape((1, 1, -1))
+    return Mesh(devices, axis_names=("data", "attn_dp", "model"))
+
+
+def test_create_kv_caches(mesh: Mesh):
+    """
+    Tests that `create_kv_caches` correctly allocates and shards the KV caches
+    for all specified layers.
+    """
+    num_blocks = 64
+    block_size = 16
+    num_kv_heads = 8
+    head_size = 128
+    layer_names = ["decoder.0", "decoder.1", "decoder.2"]  # Test with 3 layers
+
+    expected_sharding = NamedSharding(
+        mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None, "model"))
+    expected_dtype = jnp.bfloat16
+    expected_shape = get_kv_cache_shape_with_mesh(mesh, num_blocks, block_size,
+                                                  num_kv_heads, head_size,
+                                                  expected_dtype)
+
+    with patch("tpu_inference.logger.init_logger",
+               return_value=MagicMock()), patch(
+                   "tpu_inference.utils.hbm_usage_gb",
+                   return_value=[(0.0, 0.0), (0.0, 0.0)]):
+        kv_caches = create_kv_caches(
+            num_blocks=num_blocks,
+            block_size=block_size,
+            num_kv_heads=num_kv_heads,
+            head_size=head_size,
+            mesh=mesh,
+            layer_names=layer_names,
+        )
+
+    assert isinstance(kv_caches, list)
+    assert len(kv_caches) == len(layer_names)
+
+    for cache_array in kv_caches:
+        assert isinstance(cache_array, jax.Array)
+        assert cache_array.shape == expected_shape
+        assert cache_array.dtype == expected_dtype
+        assert cache_array.sharding == expected_sharding
+
+    # Ensure that separate array objects were created for each layer
+    assert kv_caches[0] is not kv_caches[1]
+
+
+def test_create_kv_caches_mla(mesh: Mesh):
+    """
+    Tests that `create_kv_caches` correctly allocates and shards the KV caches
+    for all specified layers when `use_mla` is True.
+    """
+    num_blocks = 64
+    block_size = 16
+    num_kv_heads = 1  # Not used for MLA shape calculation
+    head_size = 512 + 64  # Combined dimension for MLA
+    layer_names = ["decoder.0", "decoder.1"]
+
+    # For MLA, sharding is by the 'model' axis on the token dimension.
+    expected_sharding = NamedSharding(
+        mesh, PartitionSpec(ShardingAxisName.MLP_TENSOR))
+    expected_dtype = jnp.bfloat16
+    expected_shape = get_kv_cache_shape_with_mesh(
+        mesh,
+        num_blocks,
+        block_size,
+        num_kv_heads,
+        head_size,
+        expected_dtype,
+        use_mla=True,
+    )
+
+    with patch("tpu_inference.logger.init_logger",
+               return_value=MagicMock()), patch(
+                   "tpu_inference.utils.hbm_usage_gb",
+                   return_value=[(0.0, 0.0), (0.0, 0.0)]):
+        kv_caches = create_kv_caches(
+            num_blocks=num_blocks,
+            block_size=block_size,
+            num_kv_heads=num_kv_heads,
+            head_size=head_size,
+            mesh=mesh,
+            layer_names=layer_names,
+            use_mla=True,
+        )
+
+    assert isinstance(kv_caches, list)
+    assert len(kv_caches) == len(layer_names)
+
+    for cache_array in kv_caches:
+        assert isinstance(cache_array, jax.Array)
+        assert cache_array.shape == expected_shape
+        assert cache_array.dtype == expected_dtype
+        assert cache_array.sharding == expected_sharding
+
+
+def test_get_kv_cache_shape_with_mesh_mla(mesh: Mesh):
+    """
+    Tests `get_kv_cache_shape_with_mesh` with `use_mla=True`.
+    """
+    total_num_pages = 64
+    page_size = 16
+    actual_num_kv_heads = 1  # Not used for MLA
+    actual_head_dim = 512 + 128  # lkv_dim + r_dim
+    kv_dtype = jnp.bfloat16
+
+    # Expected shape calculation for MLA:
+    # kv_packing = 2 (for bfloat16)
+    # shape[0] = total_num_pages = 64
+    # shape[1] = align_to(page_size, 2) // 2 = 16 // 2 = 8
+    # shape[2] = 2
+    # shape[3] = align_to(actual_head_dim, 128) = align_to(640, 128) = 640
+    expected_shape = (64, 8, 2, 640)
+
+    shape = get_kv_cache_shape_with_mesh(
+        mesh,
+        total_num_pages,
+        page_size,
+        actual_num_kv_heads,
+        actual_head_dim,
+        kv_dtype,
+        use_mla=True,
+    )
+
+    assert shape == expected_shape
+
+
+def test_get_attention_page_size_bytes(mesh: Mesh):
+    """
+    Tests `get_attention_page_size_bytes`.
+    """
+    block_size = 16
+    num_kv_heads = 8
+    head_size = 128
+    dtype = torch.bfloat16
+
+    full_attn_spec = FullAttentionSpec(block_size=block_size,
+                                       num_kv_heads=num_kv_heads,
+                                       head_size=head_size,
+                                       dtype=dtype)
+
+    kv_cache_specs = {"layer.0": full_attn_spec}
+
+    page_size_bytes = get_attention_page_size_bytes(mesh, kv_cache_specs)
+
+    shape = get_kv_cache_shape_with_mesh(mesh, 1, block_size, num_kv_heads,
+                                         head_size, jnp.bfloat16)
+    expected_page_size = (
+        (32 // get_dtype_packing(jnp.bfloat16)) * np.prod(shape)) // 8
+
+    assert page_size_bytes == expected_page_size
+
+
+def test_get_attention_page_size_bytes_mla(mesh: Mesh):
+    """
+    Tests `get_attention_page_size_bytes` for MLA.
+    """
+    block_size = 16
+    num_kv_heads = 1
+    head_size = 512 + 128  # lkv_dim + r_dim
+    dtype = torch.bfloat16
+
+    mla_spec = MLAAttentionSpec(
+        block_size=block_size,
+        num_kv_heads=num_kv_heads,
+        head_size=head_size,
+        dtype=dtype,
+        cache_dtype_str="bfloat16",
+    )
+
+    kv_cache_specs = {"layer.0": mla_spec}
+
+    page_size_bytes = get_attention_page_size_bytes(mesh, kv_cache_specs)
+
+    shape = get_kv_cache_shape_with_mesh(mesh,
+                                         1,
+                                         block_size,
+                                         num_kv_heads,
+                                         head_size,
+                                         jnp.bfloat16,
+                                         use_mla=True)
+    expected_page_size = (
+        (32 // get_dtype_packing(jnp.bfloat16)) * np.prod(shape)) // 8
+
+    assert page_size_bytes == expected_page_size