tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import jax
|
|
16
|
+
from jax import numpy as jnp
|
|
17
|
+
from jax._src import test_util as jtu
|
|
18
|
+
from jax.sharding import Mesh
|
|
19
|
+
|
|
20
|
+
from tpu_inference.layers.jax.rope import (DeepseekScalingRotaryEmbedding,
|
|
21
|
+
RotaryEmbedding)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class RotaryEmbeddingTest(jtu.JaxTestCase):
    """Tests for the unscaled ``RotaryEmbedding`` layer."""

    def test_apply_rope(self):
        """Checks the sin/cos cache and the rotated output on a tiny example.

        With ``head_dim == 2`` each cache row holds one cosine and one sine
        value. The expected rows below are ``[cos(p), sin(p)]`` for positions
        ``p = 0, 1`` (cos(1) ~= 0.5403023, sin(1) ~= 0.841471).
        """
        head_dim = 2
        rope_theta = 10000
        original_max_position_embeddings = 2
        rope = RotaryEmbedding(
            rotary_dim=head_dim,
            rope_theta=rope_theta,
            original_max_position_embeddings=original_max_position_embeddings,
            dtype=jnp.float32)
        rope.initialize_cache()
        # assertEqual (rather than assertTrue(a == b)) reports both shapes
        # when the check fails.
        self.assertEqual(rope.sin_cos_cache.shape,
                         (original_max_position_embeddings, head_dim))
        expected_sin_cos = jnp.array([[1, 0], [0.5403023, 0.841471]],
                                     dtype=jnp.float32)
        self.assertArraysAllClose(rope.sin_cos_cache, expected_sin_cos)

        num_tokens = 2
        num_heads = 1
        positions = jnp.arange(num_tokens)
        x = jnp.ones((num_tokens, num_heads, head_dim))
        x_rope = rope.apply_rope(positions, x)
        # Position 0 leaves x unchanged. Position 1 rotates [1, 1] by one
        # radian: [cos(1) - sin(1), sin(1) + cos(1)].
        expected_x_rope = jnp.array([[[1, 1]], [[-0.30116874, 1.3817732]]],
                                    dtype=jnp.float32)
        self.assertEqual(x_rope.shape, x.shape)
        self.assertArraysAllClose(x_rope, expected_x_rope)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class DeepseekScalingRotaryEmbeddingTest(jtu.JaxTestCase):
    """Tests for the ``DeepseekScalingRotaryEmbedding`` layer."""

    def test_apply_rope(self):
        """Checks the scaled sin/cos cache and the rotated output.

        The cache is expected to have ``scaling_factor *
        original_max_position_embeddings`` rows, with its last dimension
        padded to 128. Only the first ``head_dim`` columns are compared.
        The expected values are the unscaled ``[cos(p), sin(p)]`` rows
        multiplied by ~1.0693147, which equals ``1 + 0.1 * ln(scaling_factor)``
        for ``scaling_factor == 2``. Presumably this is the YaRN-style
        attention scale; confirm against ``rope.py``.
        """
        head_dim = 2
        rope_theta = 10000
        original_max_position_embeddings = 1
        scaling_factor = 2
        devices = jax.devices()
        mesh = Mesh(devices, ('data', ))

        rope = DeepseekScalingRotaryEmbedding(
            rotary_dim=head_dim,
            rope_theta=rope_theta,
            original_max_position_embeddings=original_max_position_embeddings,
            scaling_factor=scaling_factor,
            dtype=jnp.float32)
        rope.initialize_cache(mesh)
        expected_padded_dim = 128
        # assertEqual (rather than assertTrue(a == b)) reports both shapes
        # when the check fails.
        self.assertEqual(
            rope.sin_cos_cache.shape,
            (scaling_factor * original_max_position_embeddings,
             expected_padded_dim))

        # Columns beyond head_dim are padding and are not checked.
        valid_cache_slice = rope.sin_cos_cache[:, :head_dim]

        expected_sin_cos = jnp.array([[1.0693147, 0], [0.5777532, 0.8997973]],
                                     dtype=jnp.float32)

        self.assertArraysAllClose(valid_cache_slice, expected_sin_cos)

        num_tokens = 2
        num_heads = 1
        positions = jnp.arange(num_tokens)
        x = jnp.ones((num_tokens, num_heads, head_dim))
        x_rope = rope.apply_rope(positions, x)
        # These are the unscaled RoPE results for [1, 1] at positions 0 and 1,
        # each multiplied by ~1.0693147.
        expected_x_rope = jnp.array(
            [[[1.0693147, 1.0693147]], [[-0.32204413, 1.4775505]]],
            dtype=jnp.float32)
        self.assertEqual(x_rope.shape, x.shape)
        self.assertArraysAllClose(x_rope, expected_x_rope)
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import unittest
|
|
16
|
+
from unittest.mock import MagicMock
|
|
17
|
+
|
|
18
|
+
import jax
|
|
19
|
+
|
|
20
|
+
from tpu_inference.layers.common.sharding import (Sharding, ShardingConfig,
|
|
21
|
+
ShardingRulesConfig,
|
|
22
|
+
ShardingStrategy)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class TestSharding(unittest.TestCase):
    """Unit test suite for the sharding configuration logic."""

    def setUp(self):
        """Replaces ``jax.devices`` with a stub returning 8 mock devices."""
        self.mock_devices = [MagicMock(coords=i) for i in range(8)]
        self.original_jax_devices = jax.devices
        # Accept and ignore any arguments (e.g. ``backend=``) so the stub can
        # stand in for ``jax.devices`` however the code under test calls it.
        jax.devices = lambda *args, **kwargs: self.mock_devices

    def tearDown(self):
        """Restores the original jax.devices function after tests."""
        jax.devices = self.original_jax_devices

    @staticmethod
    def _make_vllm_config_with_overrides():
        """Returns a mock VllmConfig with 'all' and 'prefill' logical rules.

        A fresh mock and dict are built on every call, so one test's
        ``Sharding`` instance can never observe mutations made via another.
        """
        mock_vllm_config = MagicMock()
        mock_vllm_config.additional_config = {
            "sharding": {
                "logical_rules": {
                    "all": {
                        "norm_scale": ("model", )
                    },
                    "prefill": {
                        "activation_ffw_td": ("data", "model")
                    },
                }
            }
        }
        return mock_vllm_config

    def test_sharding_strategy_init(self):
        """Tests the initialization of the ShardingStrategy."""
        strategy = ShardingStrategy(
            tensor_parallelism=2,
            expert_parallelism=4,
            data_parallelism=1,
            sequence_parallelism=1,
        )
        self.assertEqual(strategy.tensor_parallelism, 2)
        self.assertEqual(strategy.expert_parallelism, 4)

    def test_sharding_config_init(self):
        """Tests the initialization of ShardingConfig."""
        config = ShardingConfig()
        self.assertIsInstance(config.prefill_rules, ShardingRulesConfig)
        self.assertIsInstance(config.generate_rules, ShardingRulesConfig)

        custom_rules = ShardingRulesConfig(activation_ffw_td=("model", None))
        config_with_rules = ShardingConfig(prefill_rules=custom_rules)
        self.assertEqual(config_with_rules.prefill_rules.activation_ffw_td,
                         ("model", None))

    def test_apply_overrides(self):
        """Tests the _apply_overrides method for valid and invalid keys."""
        sharding = Sharding(
            prefill_rules={},
            generate_rules={},
        )
        config_obj = ShardingRulesConfig()

        valid_overrides = {"activation_ffw_td": ("model", None)}
        sharding._apply_overrides(config_obj, valid_overrides)
        self.assertEqual(config_obj.activation_ffw_td, ("model", None))

        # Keys that are not attributes of the rules config must be rejected.
        invalid_overrides = {"non_existent_attribute": (None, "model")}
        with self.assertRaises(AttributeError):
            sharding._apply_overrides(config_obj, invalid_overrides)

    def test_default_sharding_config(self):
        """Tests that default sharding rules are created correctly."""
        sharding = Sharding(
            prefill_rules={},
            generate_rules={},
        )

        sharding_cfg = sharding.get_sharding_cfg()
        generate_rules = sharding_cfg.generate_rules

        self.assertEqual(generate_rules.ffw_weight_df, (None, "model"))
        self.assertEqual(generate_rules.moe_router_de, (None, "model"))
        self.assertEqual(generate_rules.attn_q_weight_dnh,
                         (None, "model", None))

    def test_sharding_init_with_overrides(self):
        """Tests Sharding initialization with programmatic overrides."""
        generate_overrides = {"logits_tv": ("data", "model")}

        sharding = Sharding(
            generate_rules=generate_overrides,
            prefill_rules={},
        )

        sharding_cfg = sharding.get_sharding_cfg()
        self.assertNotEqual(sharding_cfg.generate_rules.logits_tv,
                            (None, "model"))
        self.assertEqual(sharding_cfg.generate_rules.logits_tv,
                         ("data", "model"))

    def test_get_overrides_from_vllm_config(self):
        """Tests fetching sharding overrides from a mock VllmConfig."""
        # The "prefill" stage receives both the "all" and "prefill" rules.
        sharding_prefill = Sharding(
            vllm_config=self._make_vllm_config_with_overrides(),
            prefill_rules={},
            generate_rules={},
        )
        prefill_overrides = sharding_prefill._get_overrides("prefill")

        self.assertEqual(prefill_overrides["norm_scale"], ("model", ))
        self.assertEqual(prefill_overrides["activation_ffw_td"],
                         ("data", "model"))

        # The "generate" stage receives the "all" rules but not the
        # prefill-only ones.
        sharding_generate = Sharding(
            vllm_config=self._make_vllm_config_with_overrides(),
            prefill_rules={},
            generate_rules={},
        )
        generate_overrides = sharding_generate._get_overrides("generate")

        self.assertEqual(generate_overrides["norm_scale"], ("model", ))
        self.assertNotIn("activation_ffw_td", generate_overrides)
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
if __name__ == "__main__":
    # Run every test in this module when the file is executed directly.
    unittest.main(module="__main__")
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import unittest
|
|
16
|
+
from unittest.mock import MagicMock
|
|
17
|
+
|
|
18
|
+
import jax.numpy as jnp
|
|
19
|
+
from flax import nnx
|
|
20
|
+
|
|
21
|
+
from tpu_inference.layers.jax.layers import DenseFFW
|
|
22
|
+
from tpu_inference.layers.jax.moe.moe import MoE
|
|
23
|
+
from tpu_inference.layers.jax.transformer_block import (
|
|
24
|
+
SharedExpertsTransformerBlock, TransformerBlock)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class TestTransformerBlock(unittest.TestCase):
    """Unit test suite for the JAX TransformerBlock module."""

    @staticmethod
    def _make_attention_mock(seq_len, hidden_size):
        """Builds a mock attention module returning a fixed (kv_cache, output).

        Args:
            seq_len: Number of tokens; the first dim of the dummy output.
            hidden_size: Model width; the second dim of the dummy output.

        Returns:
            A tuple ``(mock_attn, dummy_kv_cache, dummy_attn_output)``.
        """
        mock_attn = MagicMock(spec=nnx.Module)
        dummy_attn_output = jnp.full((seq_len, hidden_size),
                                     2.0,
                                     dtype=jnp.bfloat16)
        dummy_kv_cache = jnp.zeros((8, 16, 16, 128), dtype=jnp.bfloat16)
        mock_attn.return_value = (dummy_kv_cache, dummy_attn_output)
        return mock_attn, dummy_kv_cache, dummy_attn_output

    def test_transformer_block_dense_logic(self):
        """
        Tests the forward pass logic of a dense TransformerBlock by mocking its sub-modules.
        This test verifies the sequence of operations and residual connections.
        """
        hidden_size = 1024
        # Every dummy tensor below is sized from seq_len so the sub-module
        # outputs always match the input shape.
        seq_len = 64

        mock_pre_attn_norm = MagicMock(spec=nnx.Module)
        mock_pre_mlp_norm = MagicMock(spec=nnx.Module)

        mock_attn, dummy_kv_cache, dummy_attn_output = (
            self._make_attention_mock(seq_len, hidden_size))

        mock_mlp = MagicMock(spec=DenseFFW)
        dummy_mlp_output = jnp.full((seq_len, hidden_size),
                                    3.0,
                                    dtype=jnp.bfloat16)
        mock_mlp.return_value = dummy_mlp_output

        transformer_block = TransformerBlock(
            pre_attention_norm=mock_pre_attn_norm,
            pre_mlp_norm=mock_pre_mlp_norm,
            custom_module=mock_mlp,
            attn=mock_attn,
        )

        x = jnp.ones((seq_len, hidden_size), dtype=jnp.bfloat16)
        initial_kv_cache = MagicMock()
        attention_metadata = MagicMock()

        # Identity norms make the residual arithmetic below exactly checkable.
        mock_pre_attn_norm.side_effect = lambda val: val
        mock_pre_mlp_norm.side_effect = lambda val: val

        new_kv_cache, final_output = transformer_block(
            x,
            is_prefill=True,
            kv_cache=initial_kv_cache,
            attention_metadata=attention_metadata,
        )

        mock_pre_attn_norm.assert_called_once()
        self.assertTrue(
            jnp.array_equal(mock_pre_attn_norm.call_args.args[0], x))

        # The identity norm forwards the very same `x` object, so mock's
        # argument comparison succeeds via identity. Two distinct jax arrays
        # compared with `==` would not yield a plain bool.
        mock_attn.assert_called_once_with(x, True, initial_kv_cache,
                                          attention_metadata, True)

        # First residual: attention output + block input.
        expected_mlp_norm_input = dummy_attn_output + x

        mock_pre_mlp_norm.assert_called_once()
        self.assertTrue(
            jnp.array_equal(mock_pre_mlp_norm.call_args.args[0],
                            expected_mlp_norm_input))

        mock_mlp.assert_called_once()
        self.assertTrue(
            jnp.array_equal(mock_mlp.call_args.args[0],
                            expected_mlp_norm_input))

        # Second residual: MLP output + first residual.
        expected_final_output = dummy_mlp_output + expected_mlp_norm_input
        self.assertTrue(jnp.allclose(final_output, expected_final_output))

        self.assertTrue(jnp.array_equal(new_kv_cache, dummy_kv_cache))

    def test_shared_experts_transformer_block_logic(self):
        """Tests the forward pass logic of a SharedExpertsTransformerBlock."""
        hidden_size = 1024
        # Every dummy tensor below is sized from seq_len so the sub-module
        # outputs always match the input shape.
        seq_len = 64

        mock_pre_attn_norm = MagicMock(spec=nnx.Module)
        mock_pre_mlp_norm = MagicMock(spec=nnx.Module)

        mock_attn, dummy_kv_cache, _ = self._make_attention_mock(
            seq_len, hidden_size)

        mock_moe = MagicMock(spec=MoE)
        dummy_moe_output = jnp.full((seq_len, hidden_size),
                                    3.0,
                                    dtype=jnp.bfloat16)
        mock_moe.return_value = dummy_moe_output

        mock_shared_experts = MagicMock(spec=DenseFFW)
        dummy_shared_experts_output = jnp.full((seq_len, hidden_size),
                                               4.0,
                                               dtype=jnp.bfloat16)
        mock_shared_experts.return_value = dummy_shared_experts_output

        transformer_block = SharedExpertsTransformerBlock(
            pre_attention_norm=mock_pre_attn_norm,
            pre_mlp_norm=mock_pre_mlp_norm,
            custom_module=mock_moe,
            attn=mock_attn,
            shared_experts=mock_shared_experts,
        )

        x = jnp.ones((seq_len, hidden_size), dtype=jnp.bfloat16)
        initial_kv_cache = MagicMock()
        attention_metadata = MagicMock()

        mock_pre_attn_norm.side_effect = lambda val: val
        mock_pre_mlp_norm.side_effect = lambda val: val

        new_kv_cache, final_output = transformer_block(
            x,
            is_prefill=True,
            kv_cache=initial_kv_cache,
            attention_metadata=attention_metadata,
        )
        self.assertTrue(jnp.array_equal(new_kv_cache, dummy_kv_cache))
        self.assertEqual(final_output.shape, (seq_len, hidden_size))

        # Each sub-module is invoked exactly once per forward pass.
        self.assertEqual(mock_moe.call_count, 1)
        self.assertEqual(mock_attn.call_count, 1)
        self.assertEqual(mock_shared_experts.call_count, 1)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
if __name__ == "__main__":
    # Run every test in this module when the file is executed directly.
    unittest.main(module="__main__")
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|