tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +14 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +25 -8
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +14 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +20 -3
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +20 -26
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +22 -3
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +100 -455
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
- tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +37 -16
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +113 -124
- tpu_inference/models/jax/gpt_oss.py +23 -7
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
- tpu_inference/models/jax/utils/weight_utils.py +32 -1
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +27 -29
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +69 -35
- tpu_inference/runner/kv_cache.py +14 -0
- tpu_inference/runner/kv_cache_manager.py +15 -2
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +30 -10
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +31 -30
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +23 -7
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -208
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tests/layers/jax/attention/test_common_attention.py
@@ -0,0 +1,103 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from typing import Tuple
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax import nnx
+from jax.sharding import Mesh
+from parameterized import parameterized
+
+from tpu_inference.layers.common.attention_interface import get_kv_cache_shape
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.jax.attention.attention import Attention
+
+KVCache = Tuple[jax.Array, jax.Array]
+
+
+class TestAttention(unittest.TestCase):
+    """Unit test suite for the JAX Attention module."""
+
+    def setUp(self):
+        """Sets up the testing environment before each test."""
+        self.mesh = Mesh(
+            np.array(jax.devices()[:1]).reshape(1, 1, 1, -1),
+            axis_names=(
+                "data",
+                "attn_dp",
+                "expert",
+                "model",
+            ),
+        )
+
+    @parameterized.expand([["auto"], ["fp8"]])
+    def test_attention_forward_pass(self, kv_cache_str):
+        """Tests the forward pass of the Attention module in prefill mode."""
+        hidden_size = 1024
+        num_attention_heads = 8
+        head_dim = hidden_size // num_attention_heads
+
+        with jax.set_mesh(self.mesh):
+            attention = Attention(hidden_size=hidden_size,
+                                  num_attention_heads=num_attention_heads,
+                                  num_key_value_heads=num_attention_heads,
+                                  head_dim=head_dim,
+                                  rope_theta=10000.0,
+                                  rope_scaling={},
+                                  dtype=jnp.bfloat16,
+                                  mesh=self.mesh,
+                                  random_init=True,
+                                  rngs=nnx.Rngs(42),
+                                  kv_cache_dtype=kv_cache_str)
+
+            seq_len = 64
+            x = jnp.ones((seq_len, hidden_size), dtype=jnp.bfloat16)
+
+            block_size = 16
+            num_blocks = 8
+            kv_dtype = jnp.float8_e4m3fn if kv_cache_str == "fp8" else jnp.bfloat16
+            cache_shape = get_kv_cache_shape(num_blocks, block_size,
+                                             num_attention_heads, head_dim,
+                                             kv_dtype)
+
+            kv_cache = jnp.zeros(cache_shape, dtype=kv_dtype)
+
+            num_required_blocks = seq_len // block_size
+
+            attention_metadata = AttentionMetadata(
+                input_positions=jnp.arange(seq_len, dtype=jnp.int32),
+                block_tables=jnp.array(list(range(num_required_blocks)),
+                                       dtype=jnp.int32),
+                seq_lens=jnp.array([seq_len], dtype=jnp.int32),
+                query_start_loc=jnp.array([0, seq_len], dtype=jnp.int32),
+                request_distribution=jnp.array([0, 0, 1], dtype=jnp.int32),
+            )
+
+            new_kv_cache, output = attention(
+                x,
+                is_prefill=True,
+                kv_cache=kv_cache,
+                attention_metadata=attention_metadata,
+            )
+
+        self.assertEqual(output.shape, (seq_len, hidden_size))
+
+        self.assertEqual(new_kv_cache.shape, kv_cache.shape)
+
+
+if __name__ == "__main__":
+    unittest.main()
tests/layers/jax/attention/test_deepseek_v3_attention.py
@@ -0,0 +1,233 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax import nnx
+from jax.sharding import Mesh, PartitionSpec
+from parameterized import parameterized
+
+import tpu_inference.kernels.mla.v1.kernel as mla
+from tpu_inference.layers.common.attention_interface import get_kv_cache_shape
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.jax.attention.deepseek_v3_attention import MLA
+
+
+class TestMLA(unittest.TestCase):
+
+    def setUp(self):
+        os.environ["NEW_MODEL_DESIGN"] = "1"
+        self.mesh = Mesh(
+            np.array(jax.devices("tpu")[:1]).reshape(1, 1, 1, 1),
+            axis_names=("data", "attn_dp", "expert", "model"),
+        )
+
+    @parameterized.expand([["auto"], ["fp8"]])
+    def test_mla_forward_pass(self, kv_cache_str):
+        hidden_size = 256
+
+        num_key_value_heads = 32
+        qk_nope_head_dim = 64
+        qk_rope_head_dim = 32
+
+        with jax.set_mesh(self.mesh):
+            query_tnh_spec = PartitionSpec(None, ShardingAxisName.MLP_TENSOR,
+                                           None)
+            keyvalue_skh_spec = PartitionSpec(None,
+                                              ShardingAxisName.MLP_TENSOR,
+                                              None)
+            attn_o_tnh_spec = PartitionSpec(None, ShardingAxisName.MLP_TENSOR,
+                                            None)
+
+            mla_layer = MLA(
+                hidden_size=hidden_size,
+                num_attention_heads=32,
+                num_key_value_heads=num_key_value_heads,
+                head_dim=64,  # MLA uses v_head_dim as head_dim
+                rope_theta=10000,
+                dtype=jnp.bfloat16,
+                q_lora_rank=512,
+                kv_lora_rank=512,
+                qk_nope_head_dim=
+                qk_nope_head_dim,  # Half of DeepSeek v3's real values
+                qk_rope_head_dim=
+                qk_rope_head_dim,  # Half of DeepSeek v3's real values
+                v_head_dim=64,  # Half of DeepSeek v3's real values
+                rms_norm_eps=1e-5,
+                rngs=nnx.Rngs(42),
+                rope_scaling={
+                    "beta_fast": 32,
+                    "beta_slow": 1,
+                    "factor": 40,
+                    "mscale": 1.0,
+                    "mscale_all_dim": 1.0,
+                    "original_max_position_embeddings": 4096,
+                    "type": "yarn",
+                },
+                mesh=self.mesh,
+                random_init=True,
+                kv_cache_dtype=kv_cache_str,
+                query_tnh=query_tnh_spec,
+                keyvalue_skh=keyvalue_skh_spec,
+                attn_o_tnh=attn_o_tnh_spec,
+                q_da_sharding=(None, ShardingAxisName.VOCAB),
+                anh_sharding=(None, ShardingAxisName.MLP_TENSOR, None),
+                ap_sharding=(None, ShardingAxisName.MLP_TENSOR),
+                kv_da_sharding=(None, ShardingAxisName.VOCAB),
+                rd_sharding=(ShardingAxisName.MLP_TENSOR, None),
+            )
+
+            # Create input tensor
+            seq_len = 32
+            x = jnp.ones((seq_len, hidden_size), dtype=jnp.bfloat16)
+
+            # Create KV cache
+            qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+            block_size = 16
+            num_blocks = 8
+            kv_dtype = jnp.float8_e4m3fn if kv_cache_str == "fp8" else jnp.bfloat16
+            cache_shape = get_kv_cache_shape(num_blocks, block_size,
+                                             num_key_value_heads, qk_head_dim,
+                                             kv_dtype)
+            kv_cache = jnp.zeros(cache_shape, dtype=kv_dtype)
+
+            # Create attention metadata
+            attention_metadata = AttentionMetadata(
+                input_positions=jnp.arange(seq_len, dtype=jnp.int32),
+                block_tables=jnp.zeros((8, ), dtype=jnp.int32),
+                seq_lens=jnp.ones((1, ), dtype=jnp.int32) * seq_len,
+                query_start_loc=jnp.array(
+                    [0, seq_len], dtype=jnp.int32),  # This is cu_q_lens
+                request_distribution=jnp.array([0, 0, 1], dtype=jnp.int32),
+            )
+
+            mla_layer.rope.initialize_cache(self.mesh)
+
+            # Run forward pass
+            new_kv_cache, output = mla_layer(
+                x,
+                is_prefill=True,
+                kv_cache=kv_cache,
+                attention_metadata=attention_metadata)
+
+            # Verify output shapes
+            self.assertEqual(output.shape, (seq_len, hidden_size))
+            self.assertEqual(new_kv_cache.shape, kv_cache.shape)
+
+    @parameterized.expand([["auto"]])  # MLA kernel does not support fp8 yet
+    def test_mla_kernel_forward_pass(self, kv_cache_str):
+        hidden_size = 256
+
+        num_key_value_heads = 1
+        qk_nope_head_dim = 64
+        qk_rope_head_dim = 32
+        v_head_dim = 64
+        kv_lora_rank = 512
+
+        with jax.set_mesh(self.mesh):
+            query_tnh_spec = PartitionSpec(ShardingAxisName.MLP_TENSOR, None,
+                                           None)
+            keyvalue_skh_spec = PartitionSpec(ShardingAxisName.MLP_TENSOR,
+                                              None)
+            attn_o_tnh_spec = PartitionSpec(ShardingAxisName.MLP_TENSOR, None,
+                                            None)
+
+            mla_layer = MLA(
+                hidden_size=hidden_size,
+                num_attention_heads=32,
+                num_key_value_heads=num_key_value_heads,
+                head_dim=v_head_dim,  # MLA uses v_head_dim as head_dim
+                rope_theta=10000,
+                dtype=jnp.bfloat16,
+                q_lora_rank=512,
+                kv_lora_rank=kv_lora_rank,
+                qk_nope_head_dim=qk_nope_head_dim,
+                qk_rope_head_dim=qk_rope_head_dim,
+                v_head_dim=v_head_dim,
+                rms_norm_eps=1e-5,
+                rngs=nnx.Rngs(42),
+                rope_scaling={
+                    "beta_fast": 32,
+                    "beta_slow": 1,
+                    "factor": 40,
+                    "mscale": 1.0,
+                    "mscale_all_dim": 1.0,
+                    "original_max_position_embeddings": 4096,
+                    "type": "yarn",
+                },
+                mesh=self.mesh,
+                random_init=True,
+                kv_cache_dtype=kv_cache_str,
+                use_mla_kernel=
+                True,  # Set to true, in order to trigger MLA kernel.
+                query_tnh=query_tnh_spec,
+                keyvalue_skh=keyvalue_skh_spec,
+                attn_o_tnh=attn_o_tnh_spec,
+                q_da_sharding=(None, ShardingAxisName.VOCAB),
+                anh_sharding=(None, ShardingAxisName.MLP_TENSOR, None),
+                ap_sharding=(None, ShardingAxisName.MLP_TENSOR),
+                kv_da_sharding=(None, ShardingAxisName.VOCAB),
+                rd_sharding=(ShardingAxisName.MLP_TENSOR, None),
+            )
+
+            # Create input tensor
+            seq_len = 32
+            x = jnp.ones((seq_len, hidden_size), dtype=jnp.bfloat16)
+
+            # Create KV cache for MLA kernel
+            block_size = 16
+            num_blocks = 8
+            kv_dtype = jnp.float8_e4m3fn if kv_cache_str == "fp8" else jnp.bfloat16
+
+            # For the MLA kernel, the head dimension is the sum of qk_nope_head_dim and v_head_dim
+            # and lora rank
+            cache_shape = mla.get_kv_cache_shape(
+                num_blocks, block_size, kv_lora_rank + qk_rope_head_dim,
+                kv_dtype)
+            kv_cache = jnp.zeros(cache_shape, dtype=kv_dtype)
+
+            # Create attention metadata
+            attention_metadata = AttentionMetadata(
+                input_positions=jnp.arange(seq_len, dtype=jnp.int32),
+                block_tables=jnp.zeros((8, ), dtype=jnp.int32),
+                seq_lens=jnp.ones((1, ), dtype=jnp.int32) * seq_len,
+                query_start_loc=jnp.array([0, seq_len], dtype=jnp.int32),
+                request_distribution=jnp.array([0, 0, 1], dtype=jnp.int32),
+            )
+
+            mla_layer.rope.initialize_cache(self.mesh)
+
+            # Run forward pass
+            new_kv_cache, output = mla_layer(
+                x,
+                is_prefill=True,
+                kv_cache=kv_cache,
+                attention_metadata=attention_metadata)
+
+            # Verify output shapes
+            self.assertEqual(output.shape, (seq_len, hidden_size))
+            self.assertEqual(new_kv_cache.shape, kv_cache.shape)
+
+
+if __name__ == "__main__":
+    unittest.main()
+
+
+def tearDownModule():
+    del os.environ["NEW_MODEL_DESIGN"]
tests/layers/jax/attention/test_llama4_attention.py
@@ -0,0 +1,135 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+from dataclasses import dataclass, field
+
+import chex
+
+os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
+
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from jax.sharding import NamedSharding
+from jax.sharding import PartitionSpec as P
+
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import build_mesh
+from tpu_inference.layers.jax.attention.llama4_attention import (
+    L2Norm, Llama4Attention)
+
+
+@dataclass
+class SimpleVLLMConfig:
+    additional_config: dict = field(default_factory=dict)
+
+
+class Llama4AttentionTest(unittest.TestCase):
+    """Unit test suite for Llama4-specific attention components."""
+
+    def setUp(self):
+        devices = jax.devices()[:1]
+        sharding_strategy = {"tensor_parallelism": len(devices)}
+        self.mesh = build_mesh(devices, sharding_strategy)
+
+    def test_l2norm_forward_pass(self):
+        """Tests the forward pass of the L2Norm module with hardcoded values."""
+        eps = 1e-5
+        l2_norm = L2Norm(eps=eps)
+        x = jnp.array([[1.0, 2.0, 3.0, 4.0]], dtype=jnp.float32)
+
+        output = l2_norm(x)
+
+        self.assertEqual(output.shape, x.shape)
+        self.assertEqual(output.dtype, x.dtype)
+
+        # Expected values calculated manually:
+        # mean_sq = (1^2 + 2^2 + 3^2 + 4^2) / 4 = (1+4+9+16)/4 = 30/4 = 7.5
+        # norm_val = sqrt(7.5 + 1e-5)
+        # expected = x / norm_val
+        expected_output = jnp.array([[0.365148, 0.730297, 1.095445, 1.460594]],
+                                    dtype=jnp.float32)
+        self.assertTrue(jnp.allclose(output, expected_output, atol=1e-6))
+
+    def test_l2norm_with_zeros(self):
+        """Tests L2Norm with an all-zero input."""
+        l2_norm = L2Norm(eps=1e-5)
+        x = jnp.zeros((4, 8, 16))
+        output = l2_norm(x)
+        self.assertEqual(output.shape, x.shape)
+        # Output should be all zeros.
+        self.assertTrue(jnp.all(output == 0))
+
+    def test_l2norm_eps_effect(self):
+        """Tests the effect of the epsilon value in L2Norm."""
+        eps = 1e-3
+        l2_norm = L2Norm(eps=eps)
+        x = jax.random.normal(jax.random.PRNGKey(0), (1, 1, 128))
+        output = l2_norm(x)
+
+        mean_sq = jnp.mean(x**2, axis=-1, keepdims=True)
+        expected_output = x * jax.lax.rsqrt(mean_sq + eps)
+
+        self.assertTrue(jnp.allclose(output, expected_output))
+
+    def test_apply_temperature_tuning(self):
+        with jax.set_mesh(self.mesh):
+            hidden_size = 64
+            num_attention_heads = 4
+            head_dim = hidden_size // num_attention_heads
+
+            # Create dummy sharding objects
+            dummy_sharding = NamedSharding(self.mesh, P())
+
+            llama4_attention = Llama4Attention(
+                hidden_size=hidden_size,
+                num_attention_heads=num_attention_heads,
+                num_key_value_heads=num_attention_heads,
+                head_dim=head_dim,
+                rope_theta=10000.0,
+                rope_scaling={},
+                dtype=jnp.bfloat16,
+                kv_cache_dtype="auto",
+                use_qk_norm=False,
+                temperature_tuning=True,
+                temperature_tuning_scale=2.0,
+                temperature_tuning_floor_scale=2.0,
+                mesh=self.mesh,
+                random_init=True,
+                activation_attention_td=dummy_sharding,
+                activation_attention_out_td=dummy_sharding,
+                rngs=nnx.Rngs(42),
+            )
+
+            seq_len = 8
+            input_arr_TNH = jnp.ones((seq_len, num_attention_heads, head_dim),
+                                     dtype=jnp.bfloat16)
+            attention_metadata = AttentionMetadata(
+                input_positions=jnp.arange(seq_len, dtype=jnp.int32))
+            expected_scales = jnp.array(
+                [1, 2.375, 2.375, 3.20312, 3.20312, 3.76562, 3.76562, 4.21875],
+                dtype=jnp.bfloat16)
+            output = llama4_attention.apply_temperature_tuning(
+                attention_metadata, input_arr_TNH)
+            chex.assert_shape(output, (seq_len, num_attention_heads, head_dim))
+
+            expected_output = jnp.ones_like(
+                input_arr_TNH) * expected_scales[:, None, None]
+            chex.assert_trees_all_close(output, expected_output, atol=1e-3)
+
+
+if __name__ == "__main__":
+    unittest.main()
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.