tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.2rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +89 -26
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +46 -17
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +44 -17
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/METADATA +7 -9
- tpu_inference-0.13.2rc3.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/top_level.txt +0 -0
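The file-level summary above can be reproduced locally. The sketch below is illustrative only: it assumes you have downloaded both wheels to the working directory (the file names are placeholders to adjust), and it compares the archives' member lists with the standard zipfile module, reporting added and removed files but not per-file line counts.

# Minimal sketch: compare the file lists of two locally downloaded wheels.
# The wheel paths are placeholders; point them at wherever the wheels live.
import zipfile

OLD_WHL = "tpu_inference-0.11.1.dev202512030818-py3-none-any.whl"
NEW_WHL = "tpu_inference-0.13.2rc3-py3-none-any.whl"


def members(path: str) -> set[str]:
    """Return the set of file names stored in a wheel (a zip archive)."""
    with zipfile.ZipFile(path) as whl:
        return set(whl.namelist())


old_files, new_files = members(OLD_WHL), members(NEW_WHL)

for name in sorted(new_files - old_files):
    print(f"added:   {name}")
for name in sorted(old_files - new_files):
    print(f"removed: {name}")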
tests/layers/jax/test_transformer_block.py
@@ -0,0 +1,152 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from unittest.mock import MagicMock

import jax.numpy as jnp
from flax import nnx

from tpu_inference.layers.jax.layers import DenseFFW
from tpu_inference.layers.jax.moe.moe import MoE
from tpu_inference.layers.jax.transformer_block import (
    SharedExpertsTransformerBlock, TransformerBlock)


class TestTransformerBlock(unittest.TestCase):
    """Unit test suite for the JAX TransformerBlock module."""

    def test_transformer_block_dense_logic(self):
        """
        Tests the forward pass logic of a dense TransformerBlock by mocking its sub-modules.
        This test verifies the sequence of operations and residual connections.
        """
        hidden_size = 1024

        mock_pre_attn_norm = MagicMock(spec=nnx.Module)
        mock_pre_mlp_norm = MagicMock(spec=nnx.Module)

        mock_attn = MagicMock(spec=nnx.Module)
        dummy_attn_output = jnp.full((64, hidden_size),
                                     2.0,
                                     dtype=jnp.bfloat16)
        dummy_kv_cache = jnp.zeros((8, 16, 16, 128), dtype=jnp.bfloat16)
        mock_attn.return_value = (dummy_kv_cache, dummy_attn_output)

        mock_mlp = MagicMock(spec=DenseFFW)
        dummy_mlp_output = jnp.full((64, hidden_size), 3.0, dtype=jnp.bfloat16)
        mock_mlp.return_value = dummy_mlp_output

        transformer_block = TransformerBlock(
            pre_attention_norm=mock_pre_attn_norm,
            pre_mlp_norm=mock_pre_mlp_norm,
            custom_module=mock_mlp,
            attn=mock_attn,
        )

        seq_len = 64
        x = jnp.ones((seq_len, hidden_size), dtype=jnp.bfloat16)
        initial_kv_cache = MagicMock()
        attention_metadata = MagicMock()

        mock_pre_attn_norm.side_effect = lambda val: val
        mock_pre_mlp_norm.side_effect = lambda val: val

        new_kv_cache, final_output = transformer_block(
            x,
            is_prefill=True,
            kv_cache=initial_kv_cache,
            attention_metadata=attention_metadata,
        )

        mock_pre_attn_norm.assert_called_once()
        self.assertTrue(
            jnp.array_equal(mock_pre_attn_norm.call_args.args[0], x))

        mock_attn.assert_called_once_with(x, True, initial_kv_cache,
                                          attention_metadata, True)

        expected_mlp_norm_input = dummy_attn_output + x

        mock_pre_mlp_norm.assert_called_once()
        self.assertTrue(
            jnp.array_equal(mock_pre_mlp_norm.call_args.args[0],
                            expected_mlp_norm_input))

        mock_mlp.assert_called_once()
        self.assertTrue(
            jnp.array_equal(mock_mlp.call_args.args[0],
                            expected_mlp_norm_input))

        expected_final_output = dummy_mlp_output + expected_mlp_norm_input
        self.assertTrue(jnp.allclose(final_output, expected_final_output))

        self.assertTrue(jnp.array_equal(new_kv_cache, dummy_kv_cache))

    def test_shared_experts_transformer_block_logic(self):
        """Tests the forward pass logic of a SharedExpertsTransformerBlock."""
        hidden_size = 1024

        mock_pre_attn_norm = MagicMock(spec=nnx.Module)
        mock_pre_mlp_norm = MagicMock(spec=nnx.Module)

        mock_attn = MagicMock(spec=nnx.Module)
        dummy_attn_output = jnp.full((64, hidden_size),
                                     2.0,
                                     dtype=jnp.bfloat16)
        dummy_kv_cache = jnp.zeros((8, 16, 16, 128), dtype=jnp.bfloat16)
        mock_attn.return_value = (dummy_kv_cache, dummy_attn_output)

        mock_moe = MagicMock(spec=MoE)
        dummy_moe_output = jnp.full((64, hidden_size), 3.0, dtype=jnp.bfloat16)
        mock_moe.return_value = dummy_moe_output

        mock_shared_experts = MagicMock(spec=DenseFFW)
        dummy_shared_experts_output = jnp.full((64, hidden_size),
                                               4.0,
                                               dtype=jnp.bfloat16)
        mock_shared_experts.return_value = dummy_shared_experts_output

        transformer_block = SharedExpertsTransformerBlock(
            pre_attention_norm=mock_pre_attn_norm,
            pre_mlp_norm=mock_pre_mlp_norm,
            custom_module=mock_moe,
            attn=mock_attn,
            shared_experts=mock_shared_experts,
        )

        seq_len = 64
        x = jnp.ones((seq_len, hidden_size), dtype=jnp.bfloat16)
        initial_kv_cache = MagicMock()
        attention_metadata = MagicMock()

        mock_pre_attn_norm.side_effect = lambda val: val
        mock_pre_mlp_norm.side_effect = lambda val: val

        new_kv_cache, final_output = transformer_block(
            x,
            is_prefill=True,
            kv_cache=initial_kv_cache,
            attention_metadata=attention_metadata,
        )
        self.assertTrue(jnp.array_equal(new_kv_cache, dummy_kv_cache))
        self.assertEqual(final_output.shape, (seq_len, hidden_size))

        self.assertEqual(mock_moe.call_count, 1)
        self.assertEqual(mock_attn.call_count, 1)
        self.assertEqual(mock_shared_experts.call_count, 1)


if __name__ == "__main__":
    unittest.main()
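For context, the assertions above encode a standard pre-norm residual layout: normalize, attend, add the residual, normalize again, run the MLP (or MoE) module, add the residual again. The sketch below is illustrative only, not the actual TransformerBlock implementation: it assumes the same (kv_cache, output) return convention the mocked attention uses, and because the mocked norms are identity functions the test cannot distinguish passing the raw input from passing the normed input, so the conventional pre-norm order is assumed.

# Illustrative sketch of the dense pre-norm residual flow the test asserts.
# `attn`, `mlp`, and the norm callables stand in for the real sub-modules.
def dense_block_forward(x, kv_cache, attention_metadata, *,
                        pre_attention_norm, attn, pre_mlp_norm, mlp,
                        is_prefill=True):
    # Pre-attention norm, then attention; attn returns (kv_cache, output),
    # mirroring the mocked return value in the test above.
    kv_cache, attn_out = attn(pre_attention_norm(x), is_prefill, kv_cache,
                              attention_metadata)
    residual = attn_out + x                  # first residual connection
    # Pre-MLP norm, then the feed-forward / custom module.
    mlp_out = mlp(pre_mlp_norm(residual))
    return kv_cache, mlp_out + residual      # second residual connection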
(one of the new 13-line license-header-only files added in this release)
@@ -0,0 +1,13 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
tests/layers/vllm/test_attention.py
@@ -0,0 +1,363 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import MagicMock

import jax
import jax.numpy as jnp
import numpy as np
import pytest
import torch
import torchax
from jax.sharding import Mesh
from torchax.interop import torch_view
from vllm.attention.backends.abstract import AttentionType

from tpu_inference.layers.common.attention_metadata import AttentionMetadata
from tpu_inference.layers.vllm.attention import (PallasAttentionBackend,
                                                 PallasAttentionBackendImpl)
from tpu_inference.models.vllm.vllm_model_wrapper_context import \
    set_vllm_model_wrapper_context
from tpu_inference.runner.kv_cache import get_kv_cache_shape_with_mesh

# ---- Test Configuration & Constants ----

# Total number of tokens across all sequences in the batch
TOTAL_TOKENS = 10
# Number of sequences in the batch
NUM_SEQS = 2
# Padded maximum number of sequences
MAX_NUM_SEQS = 4
# Number of attention heads (Query)
NUM_HEADS = 8
# Number of attention heads (Key/Value) - for Grouped-Query Attention
NUM_KV_HEADS = 4
# Dimension of each attention head
HEAD_DIM = 128
# Total number of blocks in the KV cache
NUM_BLOCKS = 32
# Number of tokens per block
BLOCK_SIZE = 16
# Maximum number of blocks a single sequence can occupy
MAX_BLOCKS_PER_SEQ = 8


def create_inputs(
    mesh: Mesh,
    q_dtype: jnp.dtype = jnp.bfloat16,
    kv_dtype: jnp.dtype = jnp.bfloat16,
    total_tokens: int = TOTAL_TOKENS,
    num_seqs: int = NUM_SEQS,
    max_num_seqs: int = MAX_NUM_SEQS,
    num_heads: int = NUM_HEADS,
    num_kv_heads: int = NUM_KV_HEADS,
    head_dim: int = HEAD_DIM,
    num_blocks: int = NUM_BLOCKS,
    block_size: int = BLOCK_SIZE,
    max_blocks_per_seq: int = MAX_BLOCKS_PER_SEQ,
):
    key = jax.random.key(0)
    q = jax.random.uniform(key, (total_tokens, num_heads * head_dim),
                           dtype=q_dtype)
    k = jax.random.uniform(key, (total_tokens, num_kv_heads * head_dim),
                           dtype=q_dtype)
    v = jax.random.uniform(key, (total_tokens, num_kv_heads * head_dim),
                           dtype=q_dtype)
    q = torch_view(q)
    k = torch_view(k)
    v = torch_view(v)

    kv_cache_shape = get_kv_cache_shape_with_mesh(mesh, num_blocks, block_size,
                                                  num_kv_heads, head_dim,
                                                  kv_dtype)
    kv_cache = jax.random.normal(key, kv_cache_shape, dtype=kv_dtype)

    positions = jnp.ones((total_tokens, ), dtype=jnp.int32)
    block_tables = jnp.zeros((max_num_seqs * max_blocks_per_seq),
                             dtype=jnp.int32).reshape(-1)
    seq_lens = jnp.array([5, 5, 0, 0], dtype=jnp.int32)
    query_start_loc = jnp.array([0, 5, 10, 10, 10], dtype=jnp.int32)
    request_distribution = jnp.array([0, 0, num_seqs], dtype=jnp.int32)

    metadata = AttentionMetadata(
        input_positions=positions,
        block_tables=block_tables,
        seq_lens=seq_lens,
        query_start_loc=query_start_loc,
        request_distribution=request_distribution,
    )

    return q, k, v, kv_cache, metadata


@pytest.fixture
def mesh():
    """Provides a mock 1D JAX mesh for testing."""
    # Create a mesh with available devices, useful for running on CPU/GPU/TPU
    # For this test, it will likely be a single CPU device.
    devices = np.array(jax.local_devices())[0:1]
    if not devices.any():
        # Add a mock device if no devices are present (e.g., in a CI environment)
        devices = np.array([jax.devices("cpu")[0]])
    return Mesh(devices.reshape((-1, 1, 1)), ("data", "attn_dp", "model"))


class TestPallasAttentionBackend:

    def test_get_name(self):
        assert PallasAttentionBackend.get_name() == "PALLAS"

    def test_get_impl_cls(self):
        assert PallasAttentionBackend.get_impl_cls(
        ) == PallasAttentionBackendImpl


class TestPallasAttentionBackendImpl:

    def test_init_valid_params(self):
        impl = PallasAttentionBackendImpl(
            num_heads=32,
            head_size=128,
            scale=0.088,
            num_kv_heads=8,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype="auto",
            attn_type=AttentionType.DECODER,
        )

        assert impl.num_heads == 32
        assert impl.head_size == 128
        assert impl.scale == 0.088
        assert impl.num_kv_heads == 8
        assert impl.num_queries_per_kv == 4
        assert impl.sliding_window is None

    def test_init_with_alibi_slopes_raises_error(self):
        with pytest.raises(NotImplementedError,
                           match="Alibi slopes is not supported"):
            PallasAttentionBackendImpl(
                num_heads=32,
                head_size=128,
                scale=0.088,
                num_kv_heads=8,
                alibi_slopes=[1.0, 2.0],
                sliding_window=None,
                kv_cache_dtype="auto",
                attn_type=AttentionType.DECODER,
            )

    def test_init_with_encoder_attention_raises_error(self):
        with pytest.raises(NotImplementedError,
                           match="Encoder self-attention"):
            PallasAttentionBackendImpl(
                num_heads=32,
                head_size=128,
                scale=0.088,
                num_kv_heads=8,
                alibi_slopes=None,
                sliding_window=None,
                kv_cache_dtype="auto",
                attn_type=AttentionType.ENCODER,
            )

    def test_forward(self, mesh):
        impl = PallasAttentionBackendImpl(
            num_heads=NUM_HEADS,
            head_size=HEAD_DIM,
            scale=0.088,
            num_kv_heads=NUM_KV_HEADS,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype="auto",
            attn_type=AttentionType.DECODER,
        )

        layer = MagicMock()
        layer.layer_name = "0"

        query, key, value, kv_cache, metadata = create_inputs(mesh)

        with torchax.default_env(), set_vllm_model_wrapper_context(
                kv_caches=[kv_cache],
                mesh=mesh,
                layer_name_to_kvcache_index={'0': 0}):
            impl.forward(layer, query, key, value, torch.tensor([]), metadata)

    def test_forward_with_fp8_kv_cache(self, mesh):
        impl = PallasAttentionBackendImpl(
            num_heads=NUM_HEADS,
            head_size=HEAD_DIM,
            scale=0.088,
            num_kv_heads=NUM_KV_HEADS,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype="fp8",
            attn_type=AttentionType.DECODER,
        )

        layer = MagicMock()
        layer.layer_name = "0"
        layer._q_scale_float = None
        layer._k_scale_float = 1
        layer._v_scale_float = 1

        query, key, value, kv_cache, metadata = create_inputs(
            mesh, kv_dtype=jnp.float8_e4m3fn)

        with torchax.default_env(), set_vllm_model_wrapper_context(
                kv_caches=[kv_cache],
                mesh=mesh,
                layer_name_to_kvcache_index={'0': 0}):
            impl.forward(layer, query, key, value, torch.tensor([]), metadata)

    def test_forward_with_w8a8(self, mesh):
        impl = PallasAttentionBackendImpl(
            num_heads=NUM_HEADS,
            head_size=HEAD_DIM,
            scale=0.088,
            num_kv_heads=NUM_KV_HEADS,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype="fp8",
            attn_type=AttentionType.DECODER,
        )

        layer = MagicMock()
        layer.layer_name = "0"
        layer._q_scale_float = 1
        layer._k_scale_float = 1
        layer._v_scale_float = 1

        query, key, value, kv_cache, metadata = create_inputs(
            mesh, kv_dtype=jnp.float8_e4m3fn)

        with torchax.default_env(), set_vllm_model_wrapper_context(
                kv_caches=[kv_cache],
                mesh=mesh,
                layer_name_to_kvcache_index={'0': 0}):
            impl.forward(layer, query, key, value, torch.tensor([]), metadata)

    def test_forward_with_vllm_kv_cache_raises_error(self, mesh):
        impl = PallasAttentionBackendImpl(
            num_heads=NUM_HEADS,
            head_size=HEAD_DIM,
            scale=0.088,
            num_kv_heads=NUM_KV_HEADS,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype="auto",
            attn_type=AttentionType.DECODER,
        )

        layer = MagicMock()
        layer.layer_name = "0"

        query, key, value, kv_cache, metadata = create_inputs(mesh)

        with torchax.default_env(), set_vllm_model_wrapper_context(
                kv_caches=[kv_cache],
                mesh=mesh), pytest.raises(RuntimeError,
                                          match="should be empty but has"):
            impl.forward(layer, query, key, value, torch.tensor([1]), metadata)

    def test_forward_with_output_scale_raises_error(self, mesh):
        impl = PallasAttentionBackendImpl(
            num_heads=NUM_HEADS,
            head_size=HEAD_DIM,
            scale=0.088,
            num_kv_heads=NUM_KV_HEADS,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype="auto",
            attn_type=AttentionType.DECODER,
        )

        layer = MagicMock()
        layer.layer_name = "0"

        query, key, value, kv_cache, metadata = create_inputs(mesh)
        output_scale = torch.tensor([1.0])

        with torchax.default_env(), set_vllm_model_wrapper_context(
                kv_caches=[kv_cache],
                mesh=mesh), pytest.raises(NotImplementedError,
                                          match="fused output quantization"):
            impl.forward(layer,
                         query,
                         key,
                         value,
                         torch.tensor([]),
                         metadata,
                         output_scale=output_scale)

    def test_forward_with_attention_sink(self, mesh):
        head_dim = 64
        sinks = torch.rand([NUM_HEADS], dtype=torch.float32)

        impl = PallasAttentionBackendImpl(num_heads=NUM_HEADS,
                                          head_size=head_dim,
                                          scale=0.088,
                                          num_kv_heads=NUM_KV_HEADS,
                                          alibi_slopes=None,
                                          sliding_window=None,
                                          kv_cache_dtype="auto",
                                          attn_type=AttentionType.DECODER,
                                          sinks=sinks)
        impl.process_weights_after_loading(torch.bfloat16)

        layer = MagicMock()
        layer.layer_name = "0"

        query, key, value, kv_cache, metadata = create_inputs(
            mesh, head_dim=head_dim)

        with torchax.default_env(), set_vllm_model_wrapper_context(
                kv_caches=[kv_cache],
                mesh=mesh,
                layer_name_to_kvcache_index={'0': 0}):
            assert impl.sinks is not None
            impl.forward(layer, query, key, value, torch.tensor([]), metadata)

    def test_forward_with_attention_sink_head_dim_128_raises_error(self, mesh):
        head_dim = 128
        sinks = torch.rand([NUM_HEADS], dtype=torch.float32)

        impl = PallasAttentionBackendImpl(num_heads=NUM_HEADS,
                                          head_size=head_dim,
                                          scale=0.088,
                                          num_kv_heads=NUM_KV_HEADS,
                                          alibi_slopes=None,
                                          sliding_window=None,
                                          kv_cache_dtype="auto",
                                          attn_type=AttentionType.DECODER,
                                          sinks=sinks)
        impl.process_weights_after_loading(torch.bfloat16)

        layer = MagicMock()
        layer.layer_name = "0"

        query, key, value, kv_cache, metadata = create_inputs(
            mesh, head_dim=head_dim)

        with torchax.default_env(), set_vllm_model_wrapper_context(
                kv_caches=[kv_cache],
                mesh=mesh,
                layer_name_to_kvcache_index={'0': 0}
        ), pytest.raises(
                NotImplementedError,
                match=
                "Attention sink support is only available when head_dim==64"):
            assert impl.sinks is not None
            impl.forward(layer, query, key, value, torch.tensor([]), metadata)
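As a quick sanity check on the constants used in this test file, the grouped-query ratio and the paged-KV-cache capacity follow directly from them. The snippet below only spells out that arithmetic for illustration; it mirrors the `num_queries_per_kv == 4` assertion for the 32-head / 8-KV-head configuration and the 8-head / 4-KV-head inputs built by create_inputs().

# Worked arithmetic for the test constants above (illustration only).
NUM_HEADS, NUM_KV_HEADS = 8, 4         # grouped-query attention inputs from create_inputs()
NUM_BLOCKS, BLOCK_SIZE = 32, 16        # paged KV cache geometry
MAX_BLOCKS_PER_SEQ = 8

assert NUM_HEADS // NUM_KV_HEADS == 2          # query heads sharing each KV head in create_inputs()
assert 32 // 8 == 4                            # matches impl.num_queries_per_kv in test_init_valid_params
assert NUM_BLOCKS * BLOCK_SIZE == 512          # total tokens the paged KV cache can hold
assert MAX_BLOCKS_PER_SEQ * BLOCK_SIZE == 128  # longest sequence one block-table row can map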