tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (248)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +14 -0
  31. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  32. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
  35. tests/layers/__init__.py +13 -0
  36. tests/layers/common/__init__.py +13 -0
  37. tests/layers/common/test_attention_interface.py +156 -0
  38. tests/layers/common/test_quantization.py +149 -0
  39. tests/layers/jax/__init__.py +13 -0
  40. tests/layers/jax/attention/__init__.py +13 -0
  41. tests/layers/jax/attention/test_common_attention.py +103 -0
  42. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  43. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  44. tests/layers/jax/moe/__init__.py +13 -0
  45. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  46. tests/layers/jax/sample/__init__.py +13 -0
  47. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  48. tests/layers/jax/sample/test_sampling.py +115 -0
  49. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  50. tests/layers/jax/test_layers.py +155 -0
  51. tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
  52. tests/layers/jax/test_rope.py +93 -0
  53. tests/layers/jax/test_sharding.py +159 -0
  54. tests/layers/jax/test_transformer_block.py +152 -0
  55. tests/layers/vllm/__init__.py +13 -0
  56. tests/layers/vllm/test_attention.py +363 -0
  57. tests/layers/vllm/test_awq.py +406 -0
  58. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  59. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  61. tests/layers/vllm/test_fp8.py +17 -0
  62. tests/layers/vllm/test_mxfp4.py +320 -0
  63. tests/layers/vllm/test_unquantized.py +662 -0
  64. tests/layers/vllm/utils.py +87 -0
  65. tests/lora/__init__.py +13 -0
  66. tests/lora/conftest.py +14 -0
  67. tests/lora/test_bgmv.py +14 -0
  68. tests/lora/test_layers.py +25 -8
  69. tests/lora/test_lora.py +15 -1
  70. tests/lora/test_lora_perf.py +14 -0
  71. tests/models/__init__.py +13 -0
  72. tests/models/common/__init__.py +13 -0
  73. tests/models/common/test_model_loader.py +455 -0
  74. tests/models/jax/__init__.py +13 -0
  75. tests/models/jax/test_deepseek_v3.py +401 -0
  76. tests/models/jax/test_llama3.py +184 -0
  77. tests/models/jax/test_llama4.py +298 -0
  78. tests/models/jax/test_llama_eagle3.py +197 -0
  79. tests/models/jax/test_llama_guard_4.py +242 -0
  80. tests/models/jax/test_qwen2.py +172 -0
  81. tests/models/jax/test_qwen2_5_vl.py +605 -0
  82. tests/models/jax/test_qwen3.py +169 -0
  83. tests/models/jax/test_weight_loading.py +180 -0
  84. tests/models/jax/utils/__init__.py +13 -0
  85. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  86. tests/platforms/__init__.py +13 -0
  87. tests/platforms/test_tpu_platform.py +54 -0
  88. tests/runner/__init__.py +13 -0
  89. tests/runner/test_block_table.py +395 -0
  90. tests/runner/test_input_batch.py +226 -0
  91. tests/runner/test_kv_cache.py +220 -0
  92. tests/runner/test_kv_cache_manager.py +498 -0
  93. tests/runner/test_multimodal_manager.py +429 -0
  94. tests/runner/test_persistent_batch_manager.py +84 -0
  95. tests/runner/test_speculative_decoding_manager.py +368 -0
  96. tests/runner/test_structured_decoding_manager.py +220 -0
  97. tests/runner/test_tpu_runner.py +261 -0
  98. tests/runner/test_tpu_runner_dp.py +1099 -0
  99. tests/runner/test_tpu_runner_mesh.py +200 -0
  100. tests/runner/test_utils.py +411 -0
  101. tests/spec_decode/__init__.py +13 -0
  102. tests/spec_decode/test_eagle3.py +311 -0
  103. tests/test_base.py +14 -0
  104. tests/test_tpu_info.py +14 -0
  105. tests/test_utils.py +1 -43
  106. tests/worker/__init__.py +13 -0
  107. tests/worker/tpu_worker_test.py +414 -0
  108. tpu_inference/__init__.py +14 -0
  109. tpu_inference/core/__init__.py +13 -0
  110. tpu_inference/core/sched/__init__.py +13 -0
  111. tpu_inference/core/sched/dp_scheduler.py +372 -56
  112. tpu_inference/distributed/__init__.py +13 -0
  113. tpu_inference/distributed/jax_parallel_state.py +14 -0
  114. tpu_inference/distributed/tpu_connector.py +14 -9
  115. tpu_inference/distributed/utils.py +56 -4
  116. tpu_inference/executors/__init__.py +13 -0
  117. tpu_inference/executors/ray_distributed_executor.py +20 -3
  118. tpu_inference/experimental/__init__.py +13 -0
  119. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  120. tpu_inference/kernels/__init__.py +13 -0
  121. tpu_inference/kernels/collectives/__init__.py +13 -0
  122. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  123. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  124. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  125. tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
  126. tpu_inference/kernels/megablox/__init__.py +13 -0
  127. tpu_inference/kernels/megablox/common.py +54 -0
  128. tpu_inference/kernels/megablox/gmm.py +646 -0
  129. tpu_inference/kernels/mla/__init__.py +13 -0
  130. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  131. tpu_inference/kernels/mla/v1/kernel.py +20 -26
  132. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  133. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  134. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  135. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  136. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
  137. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
  138. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  139. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
  140. tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
  141. tpu_inference/layers/__init__.py +13 -0
  142. tpu_inference/layers/common/__init__.py +13 -0
  143. tpu_inference/layers/common/attention_interface.py +26 -19
  144. tpu_inference/layers/common/attention_metadata.py +14 -0
  145. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  146. tpu_inference/layers/common/quant_methods.py +15 -0
  147. tpu_inference/layers/common/quantization.py +282 -0
  148. tpu_inference/layers/common/sharding.py +22 -3
  149. tpu_inference/layers/common/utils.py +94 -0
  150. tpu_inference/layers/jax/__init__.py +13 -0
  151. tpu_inference/layers/jax/attention/__init__.py +13 -0
  152. tpu_inference/layers/jax/attention/attention.py +19 -6
  153. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
  154. tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
  155. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  156. tpu_inference/layers/jax/base.py +14 -0
  157. tpu_inference/layers/jax/constants.py +13 -0
  158. tpu_inference/layers/jax/layers.py +14 -0
  159. tpu_inference/layers/jax/misc.py +14 -0
  160. tpu_inference/layers/jax/moe/__init__.py +13 -0
  161. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  162. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  163. tpu_inference/layers/jax/moe/moe.py +43 -3
  164. tpu_inference/layers/jax/pp_utils.py +53 -0
  165. tpu_inference/layers/jax/rope.py +14 -0
  166. tpu_inference/layers/jax/rope_interface.py +14 -0
  167. tpu_inference/layers/jax/sample/__init__.py +13 -0
  168. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  169. tpu_inference/layers/jax/sample/sampling.py +15 -1
  170. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  171. tpu_inference/layers/jax/transformer_block.py +14 -0
  172. tpu_inference/layers/vllm/__init__.py +13 -0
  173. tpu_inference/layers/vllm/attention.py +4 -4
  174. tpu_inference/layers/vllm/fused_moe.py +100 -455
  175. tpu_inference/layers/vllm/linear.py +64 -0
  176. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  177. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  178. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  179. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  180. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  181. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  182. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  183. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
  184. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  188. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
  189. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  190. tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
  191. tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
  192. tpu_inference/lora/__init__.py +13 -0
  193. tpu_inference/lora/torch_lora_ops.py +8 -13
  194. tpu_inference/models/__init__.py +13 -0
  195. tpu_inference/models/common/__init__.py +13 -0
  196. tpu_inference/models/common/model_loader.py +37 -16
  197. tpu_inference/models/jax/__init__.py +13 -0
  198. tpu_inference/models/jax/deepseek_v3.py +113 -124
  199. tpu_inference/models/jax/gpt_oss.py +23 -7
  200. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  201. tpu_inference/models/jax/llama3.py +99 -36
  202. tpu_inference/models/jax/llama4.py +14 -0
  203. tpu_inference/models/jax/llama_eagle3.py +14 -0
  204. tpu_inference/models/jax/llama_guard_4.py +15 -1
  205. tpu_inference/models/jax/qwen2.py +17 -2
  206. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  207. tpu_inference/models/jax/qwen3.py +17 -2
  208. tpu_inference/models/jax/utils/__init__.py +13 -0
  209. tpu_inference/models/jax/utils/file_utils.py +14 -0
  210. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  211. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
  213. tpu_inference/models/jax/utils/weight_utils.py +32 -1
  214. tpu_inference/models/vllm/__init__.py +13 -0
  215. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
  216. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  217. tpu_inference/platforms/__init__.py +14 -0
  218. tpu_inference/platforms/tpu_platform.py +27 -29
  219. tpu_inference/runner/__init__.py +13 -0
  220. tpu_inference/runner/compilation_manager.py +69 -35
  221. tpu_inference/runner/kv_cache.py +14 -0
  222. tpu_inference/runner/kv_cache_manager.py +15 -2
  223. tpu_inference/runner/lora_utils.py +16 -1
  224. tpu_inference/runner/multimodal_manager.py +16 -2
  225. tpu_inference/runner/persistent_batch_manager.py +14 -0
  226. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  227. tpu_inference/runner/structured_decoding_manager.py +14 -0
  228. tpu_inference/runner/tpu_runner.py +30 -10
  229. tpu_inference/spec_decode/__init__.py +13 -0
  230. tpu_inference/spec_decode/jax/__init__.py +13 -0
  231. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  232. tpu_inference/tpu_info.py +14 -0
  233. tpu_inference/utils.py +31 -30
  234. tpu_inference/worker/__init__.py +13 -0
  235. tpu_inference/worker/tpu_worker.py +23 -7
  236. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
  237. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  238. tpu_inference/layers/vllm/linear_common.py +0 -208
  239. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  240. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  241. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  242. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  245. tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
  246. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  247. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  248. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,115 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # tests/layers/jax/sample/test_sampling.py
16
+ import jax.numpy as jnp
17
+ import numpy as np
18
+ from vllm.v1.outputs import LogprobsTensors
19
+
20
+ from tpu_inference.layers.jax.sample.sampling import (compute_logprobs,
21
+ gather_logprobs)
22
+
23
+
24
class TestSampling:
    """Checks for the `compute_logprobs` and `gather_logprobs` helpers."""

    def test_compute_logprobs(self):
        """Compare `compute_logprobs` against precomputed reference values."""
        raw_scores = jnp.array([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]],
                               dtype=jnp.float32)
        actual = compute_logprobs(raw_scores)

        # Reference values were computed with scipy.special.log_softmax.
        # The second input row is the first one reversed, so its reference
        # row is the reversed reference as well.
        ascending = [-2.40760596, -1.40760596, -0.40760596]
        reference = np.array([ascending, ascending[::-1]], dtype=np.float32)

        assert np.allclose(actual, reference, atol=1e-6)

    def test_gather_logprobs(self):
        """Check token ids, values and ranks when all logprobs are distinct.

        The expectations put the selected token in column 0, followed by
        the `num_logprobs` best-scoring entries of the same row.
        """
        score_table = jnp.array(
            [
                [-2.40760596, -1.40760596, -0.40760596, -3.40760596],
                [-0.40760596, -1.40760596, -2.40760596, -3.40760596],
            ],
            dtype=jnp.float32,
        )
        chosen = jnp.array([2, 0], dtype=jnp.int32)
        top_n = 2

        gathered: LogprobsTensors = gather_logprobs(score_table, chosen,
                                                    top_n)

        # Token ids: the selected token first, then the top-2 ids per row
        # (row 0: token 2, then 2 and 1; row 1: token 0, then 0 and 1).
        want_ids = np.array([[2, 2, 1], [0, 0, 1]], dtype=np.int32)
        assert np.array_equal(gathered.logprob_token_ids, want_ids)

        # Values: in both rows the selected token is also the best-scoring
        # one, so the two expected rows are identical.
        want_row = [-0.40760596, -0.40760596, -1.40760596]
        want_values = np.array([want_row, want_row], dtype=np.float32)
        assert np.allclose(gathered.logprobs, want_values, atol=1e-6)

        # Ranks: both selected tokens are expected at rank 1.
        want_ranks = np.array([1, 1], dtype=np.int32)
        assert np.array_equal(gathered.selected_token_ranks, want_ranks)

    def test_gather_logprobs_with_ties(self):
        """Check `gather_logprobs` on a row that contains tied logprobs."""
        score_table = jnp.array([[-1.0, -1.0, -2.0, -2.0]],
                                dtype=jnp.float32)
        chosen = jnp.array([1], dtype=jnp.int32)
        top_n = 3

        gathered: LogprobsTensors = gather_logprobs(score_table, chosen,
                                                    top_n)

        # Values: the selected token's logprob, then the three best values.
        want_values = np.array([[-1.0, -1.0, -1.0, -2.0]], dtype=np.float32)
        assert np.allclose(gathered.logprobs, want_values, atol=1e-6)

        # Ranks: token 1 is expected at rank 2 because two entries of the
        # row are >= -1.0.
        want_ranks = np.array([2], dtype=np.int32)
        assert np.array_equal(gathered.selected_token_ranks, want_ranks)

        # Token ids: the order of tied elements is not guaranteed, so only
        # the sorted top-k ids are compared. Either index 2 or index 3 may
        # fill the last slot since both hold -2.0.
        assert gathered.logprob_token_ids[0, 0] == 1
        ranked_ids = sorted(gathered.logprob_token_ids[0, 1:].tolist())
        assert ranked_ids in ([0, 1, 2], [0, 1, 3])
@@ -0,0 +1,254 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+
17
+ import jax
18
+ import jax.numpy as jnp
19
+ import numpy as np
20
+ import pytest
21
+ from jax.sharding import Mesh, NamedSharding, PartitionSpec
22
+
23
+ from tpu_inference.layers.jax.sample.sampling_metadata import (
24
+ DEFAULT_SAMPLING_PARAMS, TPUSupportedSamplingMetadata)
25
+
26
+ ## Mocks and Fixtures
27
+
28
+
29
@dataclass
class MockInputBatch:
    """Lightweight stand-in for InputBatch that holds NumPy arrays as its CPU tensors."""

    # The only field without a default; callers must always supply it.
    all_greedy: bool
    # Request count; defaults to zero.
    num_reqs: int = 0
    # Per-request sampling parameters.
    # NOTE(review): these default to None although the annotations are not
    # Optional - confirm whether the annotations should be widened.
    temperature_cpu: np.ndarray = None
    top_k_cpu: np.ndarray = None
    top_p_cpu: np.ndarray = None
    # Same caveat as above: None is the default despite the `int` annotation.
    max_num_logprobs: int = None
39
+
40
+
41
@pytest.fixture(scope="module")
def mesh() -> Mesh:
    """Provide a one-dimensional mesh over the visible JAX devices.

    The mesh has a single axis named "data". The requesting test is
    skipped when `jax.devices()` returns an empty sequence. The fixture is
    module-scoped.
    """
    available = jax.devices()
    if not available:
        pytest.skip("No JAX devices available for testing.")
    device_grid = np.array(available)
    return Mesh(device_grid, axis_names=("data", ))
47
+
48
+
49
+ ## Test Cases
50
+
51
+
52
def test_from_input_batch_all_greedy(mesh: Mesh):
    """Check `from_input_batch` for a batch flagged `all_greedy=True`.

    The returned metadata is expected to have a falsy `do_sampling` and
    `None` for the temperature, top-k and top-p tensors.
    """
    greedy_batch = MockInputBatch(all_greedy=True)

    result = TPUSupportedSamplingMetadata.from_input_batch(
        mesh=mesh, input_batch=greedy_batch, padded_num_reqs=4)

    assert not result.do_sampling, "do_sampling should be False for greedy requests"
    # None of the sampling tensors should have been materialized.
    for field_name in ("temperature", "top_k", "top_p"):
        assert getattr(result, field_name) is None
68
+
69
+
70
def test_from_input_batch_with_sampling_and_padding(mesh: Mesh):
    """Check `from_input_batch` with sampling on and 2 of 4 slots in use.

    Verified in order: the `do_sampling` flag and tensor types, the padded
    shapes, the sharding, and finally that the two unused trailing slots
    carry the values from `DEFAULT_SAMPLING_PARAMS`.
    """
    live_reqs, padded_reqs = 2, 4

    # The CPU-side arrays already have the padded length; the trailing
    # zeros are placeholders the test expects to be replaced by defaults.
    batch = MockInputBatch(
        all_greedy=False,
        num_reqs=live_reqs,
        temperature_cpu=np.array([0.7, 0.8, 0.0, 0.0], dtype=np.float32),
        top_k_cpu=np.array([10, 20, 0, 0], dtype=np.int32),
        top_p_cpu=np.array([0.9, 0.95, 0.0, 0.0], dtype=np.float32),
    )

    result = TPUSupportedSamplingMetadata.from_input_batch(
        mesh=mesh, input_batch=batch, padded_num_reqs=padded_reqs)

    # 1. Flag and tensor types.
    assert result.do_sampling, "do_sampling should be True"
    tensors = (result.temperature, result.top_k, result.top_p)
    for tensor in tensors:
        assert isinstance(tensor, jnp.ndarray)

    # 2. Every tensor has exactly one entry per padded request slot.
    for tensor in tensors:
        assert tensor.shape == (padded_reqs, )

    # 3. Sharding is expected to be fully replicated over the mesh.
    replicated = NamedSharding(mesh, PartitionSpec(None))
    for tensor in tensors:
        assert tensor.sharding == replicated

    # 4. Live slots keep their values; padded slots hold the defaults.
    default_temp = DEFAULT_SAMPLING_PARAMS["temperature"]
    default_top_k = DEFAULT_SAMPLING_PARAMS["top_k"]
    default_top_p = DEFAULT_SAMPLING_PARAMS["top_p"]
    want_temp = np.array([0.7, 0.8, default_temp, default_temp],
                         dtype=np.float32)
    want_top_k = np.array([10, 20, default_top_k, default_top_k],
                          dtype=np.int32)
    want_top_p = np.array([0.9, 0.95, default_top_p, default_top_p],
                          dtype=np.float32)

    np.testing.assert_allclose(np.asarray(result.temperature), want_temp)
    np.testing.assert_array_equal(np.asarray(result.top_k), want_top_k)
    np.testing.assert_allclose(np.asarray(result.top_p), want_top_p)
137
+
138
+
139
def test_from_input_batch_no_padding_needed(mesh: Mesh):
    """Check `from_input_batch` when `num_reqs` equals `padded_num_reqs`.

    With every slot in use, the resulting tensors are expected to equal
    the input arrays unchanged.
    """
    slot_count = 4

    temperatures = np.array([0.7, 0.8, 0.6, 0.5], dtype=np.float32)
    top_ks = np.array([10, 20, 30, 40], dtype=np.int32)
    top_ps = np.array([0.9, 0.95, 0.85, 0.8], dtype=np.float32)

    full_batch = MockInputBatch(
        all_greedy=False,
        num_reqs=slot_count,
        temperature_cpu=temperatures,
        top_k_cpu=top_ks,
        top_p_cpu=top_ps,
    )

    result = TPUSupportedSamplingMetadata.from_input_batch(
        mesh=mesh, input_batch=full_batch, padded_num_reqs=slot_count)

    assert result.do_sampling
    # No slot is padding, so outputs must match the inputs element-wise.
    np.testing.assert_allclose(np.asarray(result.temperature), temperatures)
    np.testing.assert_array_equal(np.asarray(result.top_k), top_ks)
    np.testing.assert_allclose(np.asarray(result.top_p), top_ps)
166
+
167
+
168
def test_jax_tree_util_registration():
    """Round-trip a `TPUSupportedSamplingMetadata` through `jax.tree_util`.

    Flattening is expected to yield exactly three leaves, in the order
    temperature, top_k, top_p. Unflattening must reproduce the tensors and
    the `do_sampling` flag.
    """
    temperature = jnp.array([0.7])
    top_k = jnp.array([10])
    top_p = jnp.array([0.9])
    source = TPUSupportedSamplingMetadata(
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sampling=True,
    )

    # Flatten: only the three tensors should appear as leaves.
    leaves, treedef = jax.tree_util.tree_flatten(source)
    assert len(leaves) == 3
    for leaf, wanted in zip(leaves, (temperature, top_k, top_p)):
        np.testing.assert_array_equal(leaf, wanted)

    # Unflatten: the rebuilt object must match the one we started from.
    rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
    assert rebuilt.do_sampling == source.do_sampling
    for field_name in ("temperature", "top_k", "top_p"):
        np.testing.assert_array_equal(getattr(rebuilt, field_name),
                                      getattr(source, field_name))
198
+
199
+
200
def test_from_input_batch_with_logprobs(mesh: Mesh):
    """Check how `max_num_logprobs` maps to the metadata's `logprobs` flag.

    Three greedy batches are tried: a positive count (flag expected
    truthy), zero, and None (flag expected falsy in both cases).
    """
    # (max_num_logprobs, expected truthiness of `logprobs`, failure message)
    scenarios = (
        (5, True, "logprobs should be True when max_num_logprobs > 0"),
        (0, False, "logprobs should be False when max_num_logprobs is 0"),
        (None, False,
         "logprobs should be False when max_num_logprobs is None"),
    )

    for requested, should_be_set, failure_message in scenarios:
        batch = MockInputBatch(all_greedy=True, max_num_logprobs=requested)
        result = TPUSupportedSamplingMetadata.from_input_batch(
            mesh=mesh,
            input_batch=batch,
            padded_num_reqs=4,
        )
        assert bool(result.logprobs) == should_be_set, failure_message
233
+
234
+
235
def test_from_input_batch_sampling_with_logprobs(mesh: Mesh):
    """Check that sampling and logprobs can both be enabled on one batch."""
    padded_reqs = 4
    slot_shape = (padded_reqs, )

    batch = MockInputBatch(
        all_greedy=False,
        num_reqs=2,
        temperature_cpu=np.zeros(slot_shape, dtype=np.float32),
        top_k_cpu=np.zeros(slot_shape, dtype=np.int32),
        top_p_cpu=np.zeros(slot_shape, dtype=np.float32),
        max_num_logprobs=10,
    )

    result = TPUSupportedSamplingMetadata.from_input_batch(
        mesh=mesh, input_batch=batch, padded_num_reqs=padded_reqs)

    assert result.do_sampling, "do_sampling should be True"
    assert result.logprobs, "logprobs should be True"
@@ -0,0 +1,155 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import unittest
16
+
17
+ import jax
18
+ import jax.numpy as jnp
19
+ import numpy as np
20
+ from flax import nnx
21
+ from jax.sharding import Mesh
22
+
23
+ from tpu_inference.layers.jax.layers import DenseFFW, Embedder, RMSNorm
24
+
25
+
26
class TestLayers(unittest.TestCase):
    """Forward-pass checks for the RMSNorm, DenseFFW and Embedder layers."""

    def setUp(self):
        """Build a (1, num_devices) mesh with axes "expert" and "model"."""
        device_grid = np.array(jax.devices()).reshape(1, -1)
        self.mesh = Mesh(device_grid, axis_names=("expert", "model"))

    def test_rmsnorm_forward_pass(self):
        """RMSNorm keeps shape, returns float32, and yields unit mean square."""
        with jax.set_mesh(self.mesh):
            feature_dim, num_tokens = 512, 128

            layer = RMSNorm(
                dims=feature_dim,
                random_init=True,
                epsilon=1e-5,
                rngs=nnx.Rngs(0),
                dtype=jnp.float32,
            )
            inputs = jax.random.normal(jax.random.PRNGKey(42),
                                       (num_tokens, feature_dim))

            normed = layer(inputs)

            self.assertEqual(normed.shape, inputs.shape)
            self.assertEqual(normed.dtype, jnp.float32)

            # Each row's mean of squared outputs is expected to be ~1.
            row_mean_sq = jnp.mean(jnp.square(normed), axis=-1)
            is_unit = jnp.allclose(row_mean_sq, 1.0, atol=1e-5)
            self.assertTrue(is_unit.all())

    def test_denseffw_forward_pass(self):
        """DenseFFW output has the same shape and dtype as its input."""
        with jax.set_mesh(self.mesh):
            model_dim, inner_dim, num_tokens = 512, 2048, 128

            layer = DenseFFW(
                random_init=True,
                dtype=jnp.bfloat16,
                hidden_act="silu",
                hidden_size=model_dim,
                intermediate_size=inner_dim,
                rngs=nnx.Rngs(0),
            )
            inputs = jnp.ones((num_tokens, model_dim), dtype=jnp.bfloat16)

            result = layer(inputs)

            self.assertEqual(result.shape, inputs.shape)
            self.assertEqual(result.dtype, inputs.dtype)

    def test_embedder_forward_pass(self):
        """Embedder output shapes/dtypes for `decode=False` and `decode=True`."""
        with jax.set_mesh(self.mesh):
            model_dim, vocab, num_tokens = 512, 32000, 128
            param_dtype = jnp.bfloat16

            layer = Embedder(
                vocab_size=vocab,
                hidden_size=model_dim,
                dtype=param_dtype,
                random_init=True,
                rngs=nnx.Rngs(0),
            )

            # decode=False: token ids in, one hidden vector per token out.
            ids = jnp.arange(num_tokens, dtype=jnp.int32) % vocab
            embedded = layer(ids, decode=False)
            self.assertEqual(embedded.shape, (num_tokens, model_dim))
            self.assertEqual(embedded.dtype, param_dtype)

            # decode=True: hidden vectors in, one vocab-sized row per token.
            hidden = jnp.ones((num_tokens, model_dim), dtype=jnp.bfloat16)
            scores = layer(hidden, decode=True)
            self.assertEqual(scores.shape, (num_tokens, vocab))
            self.assertEqual(scores.dtype, param_dtype)

    def test_embedder_normalization(self):
        """`normalize_embeddings=True` should scale output by sqrt(hidden_size).

        Two embedders are built from identically seeded RNGs, differing
        only in the normalization flag, and their outputs are compared.
        """
        with jax.set_mesh(self.mesh):
            model_dim = 512
            shared_config = dict(
                vocab_size=32000,
                hidden_size=model_dim,
                dtype=jnp.float32,
                random_init=True,
            )

            # Each embedder gets its own Rngs object with the same seed.
            scaled_layer = Embedder(normalize_embeddings=True,
                                    rngs=nnx.Rngs(42),
                                    **shared_config)
            plain_layer = Embedder(rngs=nnx.Rngs(42), **shared_config)

            ids = jnp.arange(10, dtype=jnp.int32)
            scaled_out = scaled_layer(ids, decode=False)
            plain_out = plain_layer(ids, decode=False)

            wanted = plain_out * jnp.sqrt(model_dim)
            matches = jnp.allclose(scaled_out, wanted, atol=1e-6)
            self.assertTrue(matches.all())
152
+
153
+
154
+ if __name__ == "__main__":
155
+ unittest.main()