tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +22 -1
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +31 -9
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +77 -36
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -71
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +158 -63
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +53 -30
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +105 -57
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +65 -19
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +65 -52
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
# tests/layers/jax/sample/test_sampling.py
|
|
16
|
+
import jax.numpy as jnp
|
|
17
|
+
import numpy as np
|
|
18
|
+
from vllm.v1.outputs import LogprobsTensors
|
|
19
|
+
|
|
20
|
+
from tpu_inference.layers.jax.sample.sampling import (compute_logprobs,
|
|
21
|
+
gather_logprobs)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class TestSampling:
    """Tests for ``compute_logprobs`` and ``gather_logprobs``."""

    def test_compute_logprobs(self):
        """Compare ``compute_logprobs`` output with precomputed values."""
        logits = jnp.array([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]],
                           dtype=jnp.float32)
        logprobs = compute_logprobs(logits)

        # Expected values computed with scipy.special.log_softmax
        expected_logprobs = np.array(
            [
                [-2.40760596, -1.40760596, -0.40760596],
                [-0.40760596, -1.40760596, -2.40760596],
            ],
            dtype=np.float32,
        )
        # rtol=1e-5 mirrors the np.allclose default, so the accepted
        # tolerance is unchanged while failures now report the mismatch.
        np.testing.assert_allclose(np.asarray(logprobs),
                                   expected_logprobs,
                                   rtol=1e-5,
                                   atol=1e-6)

    def test_gather_logprobs(self):
        """Check token ids, logprob values and ranks for an untied input."""
        logprobs = jnp.array(
            [
                [-2.40760596, -1.40760596, -0.40760596, -3.40760596],
                [-0.40760596, -1.40760596, -2.40760596, -3.40760596],
            ],
            dtype=jnp.float32,
        )
        token_ids = jnp.array([2, 0], dtype=jnp.int32)
        num_logprobs = 2

        result: LogprobsTensors = gather_logprobs(logprobs, token_ids,
                                                  num_logprobs)

        # check indices
        expected_indices = np.array(
            [
                [2, 2, 1],  # token id 2, top-k are 2, 1
                [0, 0, 1],  # token id 0, top-k are 0, 1
            ],
            dtype=np.int32,
        )
        np.testing.assert_array_equal(np.asarray(result.logprob_token_ids),
                                      expected_indices)

        # check logprobs
        expected_logprobs_values = np.array(
            [
                [-0.40760596, -0.40760596, -1.40760596],
                [-0.40760596, -0.40760596, -1.40760596],
            ],
            dtype=np.float32,
        )
        np.testing.assert_allclose(np.asarray(result.logprobs),
                                   expected_logprobs_values,
                                   rtol=1e-5,
                                   atol=1e-6)

        # check ranks
        expected_ranks = np.array([1, 1], dtype=np.int32)
        np.testing.assert_array_equal(
            np.asarray(result.selected_token_ranks), expected_ranks)

    def test_gather_logprobs_with_ties(self):
        """Check ``gather_logprobs`` when several logprobs share a value."""
        logprobs = jnp.array(
            [
                [-1.0, -1.0, -2.0, -2.0],
            ],
            dtype=jnp.float32,
        )
        token_ids = jnp.array([1], dtype=jnp.int32)
        num_logprobs = 3

        result: LogprobsTensors = gather_logprobs(logprobs, token_ids,
                                                  num_logprobs)

        # check logprobs
        expected_logprobs_values = np.array(
            [
                [-1.0, -1.0, -1.0, -2.0],
            ],
            dtype=np.float32,
        )
        np.testing.assert_allclose(np.asarray(result.logprobs),
                                   expected_logprobs_values,
                                   rtol=1e-5,
                                   atol=1e-6)

        # check ranks
        # rank of token 1 is 2 because there are 2 values >= -1.0
        expected_ranks = np.array([2], dtype=np.int32)
        np.testing.assert_array_equal(
            np.asarray(result.selected_token_ranks), expected_ranks)

        # check indices
        # The order of tied elements is not guaranteed.
        # token id is 1. top-k indices are a permutation of {0, 1, 2} or {0, 1, 3}.
        assert result.logprob_token_ids[0, 0] == 1
        top_k_indices = sorted(result.logprob_token_ids[0, 1:].tolist())
        assert top_k_indices in ([0, 1, 2], [0, 1, 3])
|
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from dataclasses import dataclass
|
|
16
|
+
|
|
17
|
+
import jax
|
|
18
|
+
import jax.numpy as jnp
|
|
19
|
+
import numpy as np
|
|
20
|
+
import pytest
|
|
21
|
+
from jax.sharding import Mesh, NamedSharding, PartitionSpec
|
|
22
|
+
|
|
23
|
+
from tpu_inference.layers.jax.sample.sampling_metadata import (
|
|
24
|
+
DEFAULT_SAMPLING_PARAMS, TPUSupportedSamplingMetadata)
|
|
25
|
+
|
|
26
|
+
## Mocks and Fixtures
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
|
|
30
|
+
class MockInputBatch:
|
|
31
|
+
"""A mock of the InputBatch class, using NumPy arrays for CPU tensors."""
|
|
32
|
+
|
|
33
|
+
all_greedy: bool
|
|
34
|
+
num_reqs: int = 0
|
|
35
|
+
temperature_cpu: np.ndarray = None
|
|
36
|
+
top_k_cpu: np.ndarray = None
|
|
37
|
+
top_p_cpu: np.ndarray = None
|
|
38
|
+
max_num_logprobs: int = None
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@pytest.fixture(scope="module")
def mesh() -> Mesh:
    """Create a 1D JAX mesh named ``data`` over the available devices.

    The requesting test is skipped when ``jax.devices()`` returns an empty
    sequence.
    """
    devices = jax.devices()
    if not devices:
        pytest.skip("No JAX devices available for testing.")
    return Mesh(np.array(devices), axis_names=("data", ))
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
## Test Cases
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def test_from_input_batch_all_greedy(mesh: Mesh):
    """Check ``from_input_batch`` for a batch with ``all_greedy=True``.

    The returned metadata is expected to have a falsy ``do_sampling`` and
    ``None`` for ``temperature``, ``top_k`` and ``top_p``.
    """
    mock_batch = MockInputBatch(all_greedy=True)
    padded_num_reqs = 4

    metadata = TPUSupportedSamplingMetadata.from_input_batch(
        mesh=mesh, input_batch=mock_batch, padded_num_reqs=padded_num_reqs)

    assert not metadata.do_sampling, "do_sampling should be False for greedy requests"
    assert metadata.temperature is None
    assert metadata.top_k is None
    assert metadata.top_p is None
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def test_from_input_batch_with_sampling_and_padding(mesh: Mesh):
    """Check ``from_input_batch`` with sampling enabled and padding needed.

    Two real requests are padded to four slots. The test verifies array
    types, shapes, sharding, and that the padded slots hold the matching
    ``DEFAULT_SAMPLING_PARAMS`` values.
    """
    num_reqs = 2
    padded_num_reqs = 4

    # Input tensors must be large enough to hold the padded values.
    temp_tensor = np.array([0.7, 0.8, 0.0, 0.0], dtype=np.float32)
    top_k_tensor = np.array([10, 20, 0, 0], dtype=np.int32)
    top_p_tensor = np.array([0.9, 0.95, 0.0, 0.0], dtype=np.float32)

    mock_batch = MockInputBatch(
        all_greedy=False,
        num_reqs=num_reqs,
        temperature_cpu=temp_tensor,
        top_k_cpu=top_k_tensor,
        top_p_cpu=top_p_tensor,
    )

    metadata = TPUSupportedSamplingMetadata.from_input_batch(
        mesh=mesh, input_batch=mock_batch, padded_num_reqs=padded_num_reqs)

    # 1. Check metadata flags and types
    assert metadata.do_sampling, "do_sampling should be True"
    assert isinstance(metadata.temperature, jnp.ndarray)
    assert isinstance(metadata.top_k, jnp.ndarray)
    assert isinstance(metadata.top_p, jnp.ndarray)

    # 2. Check shapes
    assert metadata.temperature.shape == (padded_num_reqs, )
    assert metadata.top_k.shape == (padded_num_reqs, )
    assert metadata.top_p.shape == (padded_num_reqs, )

    # 3. Check sharding (should be fully replicated)
    expected_sharding = NamedSharding(mesh, PartitionSpec(None))
    assert metadata.temperature.sharding == expected_sharding
    assert metadata.top_k.sharding == expected_sharding
    assert metadata.top_p.sharding == expected_sharding

    # 4. Check that values were correctly padded
    expected_temp = np.array(
        [
            0.7, 0.8, DEFAULT_SAMPLING_PARAMS["temperature"],
            DEFAULT_SAMPLING_PARAMS["temperature"]
        ],
        dtype=np.float32,
    )
    expected_top_k = np.array(
        [
            10, 20, DEFAULT_SAMPLING_PARAMS["top_k"],
            DEFAULT_SAMPLING_PARAMS["top_k"]
        ],
        dtype=np.int32,
    )
    expected_top_p = np.array(
        [
            0.9, 0.95, DEFAULT_SAMPLING_PARAMS["top_p"],
            DEFAULT_SAMPLING_PARAMS["top_p"]
        ],
        dtype=np.float32,
    )

    np.testing.assert_allclose(np.asarray(metadata.temperature), expected_temp)
    np.testing.assert_array_equal(np.asarray(metadata.top_k), expected_top_k)
    np.testing.assert_allclose(np.asarray(metadata.top_p), expected_top_p)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def test_from_input_batch_no_padding_needed(mesh: Mesh):
    """Check ``from_input_batch`` when ``num_reqs`` equals ``padded_num_reqs``.

    With nothing to pad, the metadata arrays are expected to match the
    input arrays element for element.
    """
    num_reqs = 4
    padded_num_reqs = 4

    temp_tensor = np.array([0.7, 0.8, 0.6, 0.5], dtype=np.float32)
    top_k_tensor = np.array([10, 20, 30, 40], dtype=np.int32)
    top_p_tensor = np.array([0.9, 0.95, 0.85, 0.8], dtype=np.float32)

    mock_batch = MockInputBatch(
        all_greedy=False,
        num_reqs=num_reqs,
        temperature_cpu=temp_tensor,
        top_k_cpu=top_k_tensor,
        top_p_cpu=top_p_tensor,
    )

    metadata = TPUSupportedSamplingMetadata.from_input_batch(
        mesh=mesh, input_batch=mock_batch, padded_num_reqs=padded_num_reqs)

    assert metadata.do_sampling
    # Check that values are identical to the input, since no padding was needed
    np.testing.assert_allclose(np.asarray(metadata.temperature), temp_tensor)
    np.testing.assert_array_equal(np.asarray(metadata.top_k), top_k_tensor)
    np.testing.assert_allclose(np.asarray(metadata.top_p), top_p_tensor)
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def test_jax_tree_util_registration():
    """Check that ``TPUSupportedSamplingMetadata`` works as a JAX PyTree.

    The object is flattened with ``jax.tree_util.tree_flatten``, its three
    leaves are checked, and it is rebuilt with ``tree_unflatten`` and
    compared field by field with the original.
    """
    metadata = TPUSupportedSamplingMetadata(
        temperature=jnp.array([0.7]),
        top_k=jnp.array([10]),
        top_p=jnp.array([0.9]),
        do_sampling=True,
    )

    # Flatten the PyTree
    leaves, treedef = jax.tree_util.tree_flatten(metadata)

    # The leaves should be the "data_fields" specified in the decorator
    assert len(leaves) == 3
    np.testing.assert_array_equal(leaves[0], jnp.array([0.7]))
    np.testing.assert_array_equal(leaves[1], jnp.array([10]))
    np.testing.assert_array_equal(leaves[2], jnp.array([0.9]))

    # Reconstruct the PyTree from leaves
    new_metadata = jax.tree_util.tree_unflatten(treedef, leaves)

    # The reconstructed object should match the original
    assert new_metadata.do_sampling == metadata.do_sampling
    np.testing.assert_array_equal(new_metadata.temperature,
                                  metadata.temperature)
    np.testing.assert_array_equal(new_metadata.top_k, metadata.top_k)
    np.testing.assert_array_equal(new_metadata.top_p, metadata.top_p)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def test_from_input_batch_with_logprobs(mesh: Mesh):
    """Check how ``max_num_logprobs`` drives the metadata ``logprobs`` flag.

    The flag is expected to be truthy for ``max_num_logprobs=5`` and falsy
    for both ``0`` and ``None``.
    """
    # Case 1: Logprobs are requested
    mock_batch_with_logprobs = MockInputBatch(all_greedy=True,
                                              max_num_logprobs=5)
    metadata_with = TPUSupportedSamplingMetadata.from_input_batch(
        mesh=mesh,
        input_batch=mock_batch_with_logprobs,
        padded_num_reqs=4,
    )
    assert metadata_with.logprobs, "logprobs should be True when max_num_logprobs > 0"

    # Case 2: Logprobs are not requested (max_num_logprobs is 0)
    mock_batch_no_logprobs_zero = MockInputBatch(all_greedy=True,
                                                 max_num_logprobs=0)
    metadata_without_zero = TPUSupportedSamplingMetadata.from_input_batch(
        mesh=mesh,
        input_batch=mock_batch_no_logprobs_zero,
        padded_num_reqs=4,
    )
    assert not metadata_without_zero.logprobs, "logprobs should be False when max_num_logprobs is 0"

    # Case 3: Logprobs are not requested (max_num_logprobs is None)
    mock_batch_no_logprobs_none = MockInputBatch(all_greedy=True,
                                                 max_num_logprobs=None)
    metadata_without_none = TPUSupportedSamplingMetadata.from_input_batch(
        mesh=mesh,
        input_batch=mock_batch_no_logprobs_none,
        padded_num_reqs=4,
    )
    assert not metadata_without_none.logprobs, "logprobs should be False when max_num_logprobs is None"
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def test_from_input_batch_sampling_with_logprobs(mesh: Mesh):
    """Check that sampling and logprobs can both be enabled at once.

    A non-greedy batch with ``max_num_logprobs=10`` is expected to produce
    metadata with both ``do_sampling`` and ``logprobs`` truthy.
    """
    num_reqs = 2
    padded_num_reqs = 4
    mock_batch = MockInputBatch(
        all_greedy=False,
        num_reqs=num_reqs,
        temperature_cpu=np.zeros((padded_num_reqs, ), dtype=np.float32),
        top_k_cpu=np.zeros((padded_num_reqs, ), dtype=np.int32),
        top_p_cpu=np.zeros((padded_num_reqs, ), dtype=np.float32),
        max_num_logprobs=10,
    )

    metadata = TPUSupportedSamplingMetadata.from_input_batch(
        mesh=mesh, input_batch=mock_batch, padded_num_reqs=padded_num_reqs)

    assert metadata.do_sampling, "do_sampling should be True"
    assert metadata.logprobs, "logprobs should be True"
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import unittest
|
|
16
|
+
|
|
17
|
+
import jax
|
|
18
|
+
import jax.numpy as jnp
|
|
19
|
+
import numpy as np
|
|
20
|
+
from flax import nnx
|
|
21
|
+
from jax.sharding import Mesh
|
|
22
|
+
|
|
23
|
+
from tpu_inference.layers.jax.layers import DenseFFW, Embedder, RMSNorm
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class TestLayers(unittest.TestCase):
    """Unit test suite for common JAX layer blocks."""

    def setUp(self):
        """Build a 2D (expert, model) mesh over all devices before each test."""
        self.mesh = Mesh(
            # Shape (1, num_devices): all devices lie along the "model" axis.
            np.array(jax.devices()).reshape(1, -1),
            axis_names=(
                "expert",
                "model",
            ),
        )

    def test_rmsnorm_forward_pass(self):
        """Tests the forward pass of the RMSNorm module."""
        with jax.set_mesh(self.mesh):
            dims = 512
            epsilon = 1e-5

            norm = RMSNorm(
                dims=dims,
                random_init=True,
                epsilon=epsilon,
                rngs=nnx.Rngs(0),
                dtype=jnp.float32,
            )

            seq_len = 128
            x = jax.random.normal(jax.random.PRNGKey(42), (seq_len, dims))

            output = norm(x)

            self.assertEqual(output.shape, x.shape)
            self.assertEqual(output.dtype, jnp.float32)

            # The test expects each output row to have a mean square of 1.
            mean_of_squares = jnp.mean(jnp.square(output), axis=-1)
            # rtol=1e-5 mirrors the jnp.allclose default, so the accepted
            # tolerance is unchanged while failures report the mismatch.
            np.testing.assert_allclose(np.asarray(mean_of_squares),
                                       1.0,
                                       rtol=1e-5,
                                       atol=1e-5)

    def test_denseffw_forward_pass(self):
        """Tests the forward pass of the DenseFFW module."""
        with jax.set_mesh(self.mesh):
            hidden_size = 512
            intermediate_size = 2048

            ffw_layer = DenseFFW(
                random_init=True,
                dtype=jnp.bfloat16,
                hidden_act="silu",
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                rngs=nnx.Rngs(0),
            )

            seq_len = 128
            x = jnp.ones((seq_len, hidden_size), dtype=jnp.bfloat16)

            output = ffw_layer(x)

            self.assertEqual(output.shape, x.shape)
            self.assertEqual(output.dtype, x.dtype)

    def test_embedder_forward_pass(self):
        """Tests both the encode and decode passes of the Embedder module."""
        with jax.set_mesh(self.mesh):
            hidden_size = 512
            vocab_size = 32000
            dtype = jnp.bfloat16

            embedder = Embedder(
                vocab_size=vocab_size,
                hidden_size=hidden_size,
                dtype=dtype,
                random_init=True,
                rngs=nnx.Rngs(0),
            )

            # Encode: token ids -> embeddings.
            seq_len = 128
            token_ids = jnp.arange(seq_len, dtype=jnp.int32) % vocab_size
            embeddings = embedder(token_ids, decode=False)
            self.assertEqual(embeddings.shape, (seq_len, hidden_size))
            self.assertEqual(embeddings.dtype, dtype)

            # Decode: hidden states -> logits over the vocabulary.
            hidden_states = jnp.ones((seq_len, hidden_size),
                                     dtype=jnp.bfloat16)
            logits = embedder(hidden_states, decode=True)
            self.assertEqual(logits.shape, (seq_len, vocab_size))
            self.assertEqual(logits.dtype, dtype)

    def test_embedder_normalization(self):
        """Tests the embedding normalization feature."""
        with jax.set_mesh(self.mesh):
            hidden_size = 512
            vocab_size = 32000

            # Two Rngs built from the same seed, one per embedder.
            rngs_1 = nnx.Rngs(42)
            rngs_2 = nnx.Rngs(42)

            embedder_norm = Embedder(
                vocab_size=vocab_size,
                hidden_size=hidden_size,
                dtype=jnp.float32,
                normalize_embeddings=True,
                random_init=True,
                rngs=rngs_1,
            )

            embedder_no_norm = Embedder(
                vocab_size=vocab_size,
                hidden_size=hidden_size,
                dtype=jnp.float32,
                random_init=True,
                rngs=rngs_2,
            )

            token_ids = jnp.arange(10, dtype=jnp.int32)

            embeddings_norm = embedder_norm(token_ids, decode=False)
            embeddings_no_norm = embedder_no_norm(token_ids, decode=False)

            # Normalized embeddings are expected to equal the plain ones
            # scaled by sqrt(hidden_size).
            scaling_factor = jnp.sqrt(hidden_size)
            expected_embeddings = embeddings_no_norm * scaling_factor

            # equal_nan=False keeps the jnp.allclose semantics: matching NaNs
            # in both arrays must still fail the test.
            np.testing.assert_allclose(
                np.asarray(embeddings_norm),
                np.asarray(expected_embeddings),
                rtol=1e-5,
                atol=1e-6,
                equal_nan=False,
            )
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
|