tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260) hide show
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,93 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import jax
16
+ from jax import numpy as jnp
17
+ from jax._src import test_util as jtu
18
+ from jax.sharding import Mesh
19
+
20
+ from tpu_inference.layers.jax.rope import (DeepseekScalingRotaryEmbedding,
21
+ RotaryEmbedding)
22
+
23
+
24
class RotaryEmbeddingTest(jtu.JaxTestCase):
    """Checks the sin/cos cache and the rotation output of RotaryEmbedding."""

    def test_apply_rope(self):
        """Builds a two-position cache and compares it and apply_rope to fixed values."""
        head_dim = 2
        rope_theta = 10000
        original_max_position_embeddings = 2
        rope = RotaryEmbedding(
            rotary_dim=head_dim,
            rope_theta=rope_theta,
            original_max_position_embeddings=original_max_position_embeddings,
            dtype=jnp.float32)
        rope.initialize_cache()
        # assertEqual prints both shapes on failure; assertTrue(a == b) would
        # only report "False is not true".
        self.assertEqual(rope.sin_cos_cache.shape,
                         (original_max_position_embeddings, head_dim))
        # The expected rows match (cos(p), sin(p)) for positions p = 0 and 1.
        expected_sin_cos = jnp.array([[1, 0], [0.5403023, 0.841471]],
                                     dtype=jnp.float32)
        self.assertArraysAllClose(rope.sin_cos_cache, expected_sin_cos)

        num_tokens = 2
        num_heads = 1
        positions = jnp.arange(num_tokens)
        x = jnp.ones((num_tokens, num_heads, head_dim))
        x_rope = rope.apply_rope(positions, x)
        # Position 0 leaves the all-ones input unchanged; position 1 yields
        # (cos(1) - sin(1), cos(1) + sin(1)).
        expected_x_rope = jnp.array([[[1, 1]], [[-0.30116874, 1.3817732]]],
                                    dtype=jnp.float32)
        self.assertEqual(x_rope.shape, x.shape)
        self.assertArraysAllClose(x_rope, expected_x_rope)
52
+
53
+
54
class DeepseekScalingRotaryEmbeddingTest(jtu.JaxTestCase):
    """Checks the padded cache and rotation output of DeepseekScalingRotaryEmbedding."""

    def test_apply_rope(self):
        """Initializes the cache on a 1-D device mesh and compares against fixed values."""
        head_dim = 2
        rope_theta = 10000
        original_max_position_embeddings = 1
        scaling_factor = 2
        devices = jax.devices()
        mesh = Mesh(devices, ('data', ))

        rope = DeepseekScalingRotaryEmbedding(
            rotary_dim=head_dim,
            rope_theta=rope_theta,
            original_max_position_embeddings=original_max_position_embeddings,
            scaling_factor=scaling_factor,
            dtype=jnp.float32)
        rope.initialize_cache(mesh)
        # The test expects the cache's last dimension to be padded to 128 even
        # though only head_dim columns carry values that are checked below.
        expected_padded_dim = 128
        # assertEqual prints both shapes on failure; assertTrue(a == b) would
        # only report "False is not true".
        self.assertEqual(
            rope.sin_cos_cache.shape,
            (scaling_factor * original_max_position_embeddings,
             expected_padded_dim))

        # Compare only the leading head_dim columns; the padding is ignored.
        valid_cache_slice = rope.sin_cos_cache[:, :head_dim]

        # NOTE(review): these values equal plain (cos(p), sin(p)) multiplied by
        # ~1.0693147 -- presumably the implementation's magnitude scale for
        # scaling_factor=2; confirm against DeepseekScalingRotaryEmbedding.
        expected_sin_cos = jnp.array([[1.0693147, 0], [0.5777532, 0.8997973]],
                                     dtype=jnp.float32)

        self.assertArraysAllClose(valid_cache_slice, expected_sin_cos)

        num_tokens = 2
        num_heads = 1
        positions = jnp.arange(num_tokens)
        x = jnp.ones((num_tokens, num_heads, head_dim))
        x_rope = rope.apply_rope(positions, x)
        expected_x_rope = jnp.array(
            [[[1.0693147, 1.0693147]], [[-0.32204413, 1.4775505]]],
            dtype=jnp.float32)
        self.assertEqual(x_rope.shape, x.shape)
        self.assertArraysAllClose(x_rope, expected_x_rope)
@@ -0,0 +1,159 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import unittest
16
+ from unittest.mock import MagicMock
17
+
18
+ import jax
19
+
20
+ from tpu_inference.layers.common.sharding import (Sharding, ShardingConfig,
21
+ ShardingRulesConfig,
22
+ ShardingStrategy)
23
+
24
+
25
class TestSharding(unittest.TestCase):
    """Unit test suite for the sharding configuration logic."""

    def setUp(self):
        """Replaces ``jax.devices`` with a stub returning eight mock devices."""
        self.mock_devices = [MagicMock(coords=i) for i in range(8)]
        self.original_jax_devices = jax.devices
        # Register the restore *before* patching: cleanups run even if setUp
        # raises part-way through, whereas tearDown is skipped in that case,
        # which would leak the stub into every later test in the process.
        self.addCleanup(setattr, jax, "devices", self.original_jax_devices)
        jax.devices = lambda: self.mock_devices

    @staticmethod
    def _make_vllm_config_with_logical_rules():
        """Returns a mock VllmConfig whose additional_config carries rule overrides.

        The config defines one override under "all" and one under "prefill";
        nothing is defined under "generate".
        """
        mock_vllm_config = MagicMock()
        mock_vllm_config.additional_config = {
            "sharding": {
                "logical_rules": {
                    "all": {
                        "norm_scale": ("model", )
                    },
                    "prefill": {
                        "activation_ffw_td": ("data", "model")
                    },
                }
            }
        }
        return mock_vllm_config

    def test_sharding_strategy_init(self):
        """Tests the initialization of the ShardingStrategy."""
        strategy = ShardingStrategy(
            tensor_parallelism=2,
            expert_parallelism=4,
            data_parallelism=1,
            sequence_parallelism=1,
        )
        self.assertEqual(strategy.tensor_parallelism, 2)
        self.assertEqual(strategy.expert_parallelism, 4)

    def test_sharding_config_init(self):
        """Tests the initialization of ShardingConfig."""
        config = ShardingConfig()
        self.assertIsInstance(config.prefill_rules, ShardingRulesConfig)
        self.assertIsInstance(config.generate_rules, ShardingRulesConfig)

        custom_rules = ShardingRulesConfig(activation_ffw_td=("model", None))
        config_with_rules = ShardingConfig(prefill_rules=custom_rules)
        self.assertEqual(config_with_rules.prefill_rules.activation_ffw_td,
                         ("model", None))

    def test_apply_overrides(self):
        """Tests the _apply_overrides method for valid and invalid keys."""
        sharding = Sharding(
            prefill_rules={},
            generate_rules={},
        )
        config_obj = ShardingRulesConfig()

        valid_overrides = {"activation_ffw_td": ("model", None)}
        sharding._apply_overrides(config_obj, valid_overrides)
        self.assertEqual(config_obj.activation_ffw_td, ("model", None))

        # Unknown rule names are expected to be rejected, not silently added.
        invalid_overrides = {"non_existent_attribute": (None, "model")}
        with self.assertRaises(AttributeError):
            sharding._apply_overrides(config_obj, invalid_overrides)

    def test_default_sharding_config(self):
        """Tests that default sharding rules are created correctly."""
        sharding = Sharding(
            prefill_rules={},
            generate_rules={},
        )

        sharding_cfg = sharding.get_sharding_cfg()
        generate_rules = sharding_cfg.generate_rules

        self.assertEqual(generate_rules.ffw_weight_df, (None, "model"))
        self.assertEqual(generate_rules.moe_router_de, (None, "model"))
        self.assertEqual(generate_rules.attn_q_weight_dnh,
                         (None, "model", None))

    def test_sharding_init_with_overrides(self):
        """Tests Sharding initialization with programmatic overrides."""
        generate_overrides = {"logits_tv": ("data", "model")}

        sharding = Sharding(
            generate_rules=generate_overrides,
            prefill_rules={},
        )

        sharding_cfg = sharding.get_sharding_cfg()
        self.assertNotEqual(sharding_cfg.generate_rules.logits_tv,
                            (None, "model"))
        self.assertEqual(sharding_cfg.generate_rules.logits_tv,
                         ("data", "model"))

    def test_get_overrides_from_vllm_config(self):
        """Tests fetching sharding overrides from a mock VllmConfig."""
        # Prefill overrides should merge the "all" and "prefill" sections.
        sharding_prefill = Sharding(
            vllm_config=self._make_vllm_config_with_logical_rules(),
            prefill_rules={},
            generate_rules={},
        )
        prefill_overrides = sharding_prefill._get_overrides("prefill")

        self.assertEqual(prefill_overrides["norm_scale"], ("model", ))
        self.assertEqual(prefill_overrides["activation_ffw_td"],
                         ("data", "model"))

        # Generate overrides should see the "all" section only; the
        # prefill-specific rule must not appear.
        sharding_generate = Sharding(
            vllm_config=self._make_vllm_config_with_logical_rules(),
            prefill_rules={},
            generate_rules={},
        )
        generate_overrides = sharding_generate._get_overrides("generate")

        self.assertEqual(generate_overrides["norm_scale"], ("model", ))
        self.assertNotIn("activation_ffw_td", generate_overrides)
156
+
157
+
158
# Run the sharding test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,152 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import unittest
16
+ from unittest.mock import MagicMock
17
+
18
+ import jax.numpy as jnp
19
+ from flax import nnx
20
+
21
+ from tpu_inference.layers.jax.layers import DenseFFW
22
+ from tpu_inference.layers.jax.moe.moe import MoE
23
+ from tpu_inference.layers.jax.transformer_block import (
24
+ SharedExpertsTransformerBlock, TransformerBlock)
25
+
26
+
27
class TestTransformerBlock(unittest.TestCase):
    """Exercises the JAX transformer blocks with every sub-module mocked."""

    def test_transformer_block_dense_logic(self):
        """
        Runs a dense TransformerBlock whose norms, attention and MLP are mocks.

        Verifies what each sub-module receives and that the residual additions
        produce the expected final output and KV cache.
        """
        hidden_size = 1024
        seq_len = 64
        activation_shape = (seq_len, hidden_size)

        attn_norm = MagicMock(spec=nnx.Module)
        mlp_norm = MagicMock(spec=nnx.Module)

        # Attention returns a (kv_cache, output) pair filled with known values.
        attn = MagicMock(spec=nnx.Module)
        attn_out = jnp.full(activation_shape, 2.0, dtype=jnp.bfloat16)
        kv_cache_out = jnp.zeros((8, 16, 16, 128), dtype=jnp.bfloat16)
        attn.return_value = (kv_cache_out, attn_out)

        mlp = MagicMock(spec=DenseFFW)
        mlp_out = jnp.full(activation_shape, 3.0, dtype=jnp.bfloat16)
        mlp.return_value = mlp_out

        block = TransformerBlock(
            pre_attention_norm=attn_norm,
            pre_mlp_norm=mlp_norm,
            custom_module=mlp,
            attn=attn,
        )

        x = jnp.ones(activation_shape, dtype=jnp.bfloat16)
        kv_cache_in = MagicMock()
        metadata = MagicMock()

        def passthrough(val):
            # Identity norm, so downstream inputs can be predicted exactly.
            return val

        attn_norm.side_effect = passthrough
        mlp_norm.side_effect = passthrough

        new_kv_cache, final_output = block(
            x,
            is_prefill=True,
            kv_cache=kv_cache_in,
            attention_metadata=metadata,
        )

        attn_norm.assert_called_once()
        norm_arg = attn_norm.call_args.args[0]
        self.assertTrue(jnp.array_equal(norm_arg, x))

        attn.assert_called_once_with(x, True, kv_cache_in, metadata, True)

        # First residual: attention output added back onto the block input.
        residual_after_attn = attn_out + x

        mlp_norm.assert_called_once()
        mlp_norm_arg = mlp_norm.call_args.args[0]
        self.assertTrue(jnp.array_equal(mlp_norm_arg, residual_after_attn))

        mlp.assert_called_once()
        mlp_arg = mlp.call_args.args[0]
        self.assertTrue(jnp.array_equal(mlp_arg, residual_after_attn))

        # Second residual: MLP output added onto the post-attention stream.
        expected_output = mlp_out + residual_after_attn
        self.assertTrue(jnp.allclose(final_output, expected_output))

        self.assertTrue(jnp.array_equal(new_kv_cache, kv_cache_out))

    def test_shared_experts_transformer_block_logic(self):
        """Runs a SharedExpertsTransformerBlock with mocked MoE and shared experts.

        Only the returned KV cache, the output shape and the sub-module call
        counts are checked.
        """
        hidden_size = 1024
        seq_len = 64
        activation_shape = (seq_len, hidden_size)

        attn_norm = MagicMock(spec=nnx.Module)
        mlp_norm = MagicMock(spec=nnx.Module)

        attn = MagicMock(spec=nnx.Module)
        attn_out = jnp.full(activation_shape, 2.0, dtype=jnp.bfloat16)
        kv_cache_out = jnp.zeros((8, 16, 16, 128), dtype=jnp.bfloat16)
        attn.return_value = (kv_cache_out, attn_out)

        moe = MagicMock(spec=MoE)
        moe.return_value = jnp.full(activation_shape, 3.0, dtype=jnp.bfloat16)

        shared_experts = MagicMock(spec=DenseFFW)
        shared_experts.return_value = jnp.full(activation_shape,
                                               4.0,
                                               dtype=jnp.bfloat16)

        block = SharedExpertsTransformerBlock(
            pre_attention_norm=attn_norm,
            pre_mlp_norm=mlp_norm,
            custom_module=moe,
            attn=attn,
            shared_experts=shared_experts,
        )

        x = jnp.ones(activation_shape, dtype=jnp.bfloat16)
        kv_cache_in = MagicMock()
        metadata = MagicMock()

        def passthrough(val):
            # Identity norm; keeps the mocked pipeline shape-preserving.
            return val

        attn_norm.side_effect = passthrough
        mlp_norm.side_effect = passthrough

        new_kv_cache, final_output = block(
            x,
            is_prefill=True,
            kv_cache=kv_cache_in,
            attention_metadata=metadata,
        )
        self.assertTrue(jnp.array_equal(new_kv_cache, kv_cache_out))
        self.assertEqual(final_output.shape, activation_shape)

        # Each of the three heavy sub-modules must run exactly once.
        self.assertEqual(moe.call_count, 1)
        self.assertEqual(attn.call_count, 1)
        self.assertEqual(shared_experts.call_count, 1)
149
+
150
+
151
# Run the transformer block test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.