tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.2rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (250)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +89 -26
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +46 -17
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +44 -17
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/METADATA +7 -9
  240. tpu_inference-0.13.2rc3.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,135 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import unittest
17
+ from dataclasses import dataclass, field
18
+
19
+ import chex
20
+
21
+ os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
22
+
23
+ import jax
24
+ import jax.numpy as jnp
25
+ from flax import nnx
26
+ from jax.sharding import NamedSharding
27
+ from jax.sharding import PartitionSpec as P
28
+
29
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
30
+ from tpu_inference.layers.common.sharding import build_mesh
31
+ from tpu_inference.layers.jax.attention.llama4_attention import (
32
+ L2Norm, Llama4Attention)
33
+
34
+
35
@dataclass
class SimpleVLLMConfig:
    """Minimal config stand-in that carries only an ``additional_config`` dict."""

    # default_factory gives every instance its own empty dict instead of a
    # shared mutable default.
    additional_config: dict = field(default_factory=dict)
38
+
39
+
40
class Llama4AttentionTest(unittest.TestCase):
    """Tests for the L2Norm module and Llama4Attention temperature tuning."""

    def setUp(self):
        """Build a mesh over the first available JAX device."""
        single_device = jax.devices()[:1]
        self.mesh = build_mesh(
            single_device, {"tensor_parallelism": len(single_device)})

    def test_l2norm_forward_pass(self):
        """L2Norm reproduces hand-computed values for a fixed 1x4 input."""
        inputs = jnp.array([[1.0, 2.0, 3.0, 4.0]], dtype=jnp.float32)

        # Reference values worked out by hand:
        #   mean(x^2) = (1 + 4 + 9 + 16) / 4 = 7.5
        #   reference = x / sqrt(7.5 + 1e-5)
        reference = jnp.array([[0.365148, 0.730297, 1.095445, 1.460594]],
                              dtype=jnp.float32)

        norm_layer = L2Norm(eps=1e-5)
        result = norm_layer(inputs)

        # Shape and dtype must pass through unchanged.
        self.assertEqual(result.shape, inputs.shape)
        self.assertEqual(result.dtype, inputs.dtype)
        self.assertTrue(jnp.allclose(result, reference, atol=1e-6))

    def test_l2norm_with_zeros(self):
        """An all-zero input must produce an all-zero output of equal shape."""
        zeros_in = jnp.zeros((4, 8, 16))
        norm_layer = L2Norm(eps=1e-5)

        result = norm_layer(zeros_in)

        self.assertEqual(result.shape, zeros_in.shape)
        self.assertTrue(jnp.all(result == 0))

    def test_l2norm_eps_effect(self):
        """L2Norm agrees with x * rsqrt(mean(x^2) + eps) for a larger eps."""
        epsilon = 1e-3
        sample = jax.random.normal(jax.random.PRNGKey(0), (1, 1, 128))

        result = L2Norm(eps=epsilon)(sample)

        # Reference computed directly over the last axis.
        mean_of_squares = jnp.mean(sample**2, axis=-1, keepdims=True)
        reference = sample * jax.lax.rsqrt(mean_of_squares + epsilon)

        self.assertTrue(jnp.allclose(result, reference))

    def test_apply_temperature_tuning(self):
        """apply_temperature_tuning scales an all-ones input per position.

        The expected per-position scales are hard-coded bfloat16 values for
        positions 0..7 with tuning scale 2.0 and floor scale 2.0.
        """
        # Plain Python sizes; nothing here depends on the mesh context.
        hidden_size = 64
        num_heads = 4
        per_head_dim = hidden_size // num_heads
        num_tokens = 8

        with jax.set_mesh(self.mesh):
            # An empty PartitionSpec serves as a placeholder sharding.
            placeholder_sharding = NamedSharding(self.mesh, P())

            attention = Llama4Attention(
                hidden_size=hidden_size,
                num_attention_heads=num_heads,
                num_key_value_heads=num_heads,
                head_dim=per_head_dim,
                rope_theta=10000.0,
                rope_scaling={},
                dtype=jnp.bfloat16,
                kv_cache_dtype="auto",
                use_qk_norm=False,
                temperature_tuning=True,
                temperature_tuning_scale=2.0,
                temperature_tuning_floor_scale=2.0,
                mesh=self.mesh,
                random_init=True,
                activation_attention_td=placeholder_sharding,
                activation_attention_out_td=placeholder_sharding,
                rngs=nnx.Rngs(42),
            )

            ones_TNH = jnp.ones((num_tokens, num_heads, per_head_dim),
                                dtype=jnp.bfloat16)
            metadata = AttentionMetadata(
                input_positions=jnp.arange(num_tokens, dtype=jnp.int32))
            per_position_scales = jnp.array(
                [1, 2.375, 2.375, 3.20312, 3.20312, 3.76562, 3.76562, 4.21875],
                dtype=jnp.bfloat16)

            tuned = attention.apply_temperature_tuning(metadata, ones_TNH)
            chex.assert_shape(tuned, (num_tokens, num_heads, per_head_dim))

            # Broadcast each position's scale over the head and feature axes.
            reference = jnp.ones_like(
                ones_TNH) * per_position_scales[:, None, None]
            chex.assert_trees_all_close(tuned, reference, atol=1e-3)
132
+
133
+
134
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,235 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import unittest
16
+
17
+ import jax
18
+ import jax.numpy as jnp
19
+ import numpy as np
20
+ from flax import nnx
21
+ from jax.sharding import Mesh, PartitionSpec
22
+
23
+ from tpu_inference.layers.jax.moe.deepseek_v3_moe import (DeepSeekV3Router,
24
+ SparseMoE)
25
+
26
+
27
class TestDeepSeekV3Router(unittest.TestCase):
    """Tests for DeepSeekV3Router top-k expert selection and its call path."""

    def setUp(self):
        """Create a 1-D mesh named 'data' over the CPU devices."""
        self.cpu_mesh = Mesh(jax.devices('cpu'), axis_names=('data', ))

    def _make_router(self, num_experts, n_groups):
        """Build a randomly initialised router; call inside the mesh context.

        Every test shares the same settings except the expert count and the
        number of expert groups.
        """
        return DeepSeekV3Router(random_init=True,
                                hidden_size=512,
                                num_experts=num_experts,
                                num_experts_per_tok=2,
                                n_groups=n_groups,
                                topk_groups=1,
                                norm_topk_prob=True,
                                routed_scaling_factor=1.0,
                                dtype=jnp.bfloat16,
                                rngs=nnx.Rngs(42))

    def test_get_topk_indices_single_group(self):
        """Test get_topk_indices with single expert group."""
        with jax.set_mesh(self.cpu_mesh):
            router = self._make_router(num_experts=4, n_groups=1)
            # Zero the bias so only the raw scores decide the selection.
            router.bias_E = jnp.zeros((4, ))

            token_scores = jnp.array([[0.1, 0.3, 0.2, 0.4]])  # shape: (1, 4)
            picked = router.get_topk_indices(token_scores)

            # The two highest scores are 0.4 (expert 3) and 0.3 (expert 1).
            wanted = jnp.array([[3, 1]])
            self.assertTrue(jnp.array_equal(picked, wanted))

    def test_get_topk_indices_2_groups(self):
        """Test get_topk_indices with 2 expert groups."""
        with jax.set_mesh(self.cpu_mesh):
            router = self._make_router(num_experts=4, n_groups=2)
            router.bias_E = jnp.zeros((4, ))

            # 4 experts, 2 groups, 2 experts per group
            token_scores = jnp.array([[[0.1, 0.3, 0.2,
                                        0.4]]])  # shape: (1, 1, 4)
            picked = router.get_topk_indices(token_scores)

            # With a single group kept (topk_groups=1), the expected picks are
            # experts 3 and 2.
            wanted = jnp.array([[[3, 2]]])
            self.assertTrue(jnp.array_equal(picked, wanted))

    def test_router_e2e(self):
        """Calling the router on a (2, 512) input yields (2, 2) outputs."""
        with jax.set_mesh(self.cpu_mesh):
            router = self._make_router(num_experts=8, n_groups=2)
            tokens = jnp.ones((2, 512))

            routing_weights, routing_indices = router(tokens)

            # Two tokens, two experts per token.
            self.assertEqual(routing_weights.shape, (2, 2))
            self.assertEqual(routing_indices.shape, (2, 2))
94
+
95
+
96
class TestSparseMoE(unittest.TestCase):
    """Checks the expert-sharded SparseMoE forward pass against a dense loop."""

    def setUp(self):
        """Set up a multi-device mesh and a sample MoE layer for testing.

        The test is skipped when fewer than ``num_expert_parallelism`` devices
        are visible. The mesh is built from exactly that many devices, so the
        expert count always divides evenly across the 'model' axis no matter
        how many devices the host exposes.
        """
        # --- Model Configuration ---
        self.B, self.S, self.D = 2, 4, 16  # Batch, Sequence, Hidden Dim
        self.E, self.K = 16, 8  # Num Experts, Experts per Token
        self.moe_intermediate_size = 32  # FFN Dim
        self.num_expert_parallelism = 8  # Shard experts across 8 devices

        devices = jax.devices()
        self.device_count = len(devices)
        if self.device_count < self.num_expert_parallelism:
            self.skipTest("This test requires at least 8 simulated devices.")

        # Use only the first `num_expert_parallelism` devices. Building the
        # mesh over every visible device would break expert sharding whenever
        # the device count does not divide the number of experts.
        mesh_devices = devices[:self.num_expert_parallelism]
        mesh_shape = (self.num_expert_parallelism, 1)
        device_mesh_array = np.array(mesh_devices).reshape(mesh_shape)

        # 'model' carries expert parallelism; 'data' has size 1.
        axis_names = ('model', 'data')
        self.mesh = Mesh(device_mesh_array, axis_names=axis_names)

        self.key = jax.random.PRNGKey(42)
        self.x = jax.random.normal(self.key, (self.B * self.S, self.D),
                                   dtype=jnp.bfloat16)

        # --- Instantiate MoE Layer ---
        # We need to do this inside the mesh context
        with self.mesh:
            router = DeepSeekV3Router(hidden_size=self.D,
                                      num_experts=self.E,
                                      num_experts_per_tok=self.K,
                                      n_groups=1,
                                      topk_groups=1,
                                      norm_topk_prob=False,
                                      routed_scaling_factor=1.0,
                                      dtype=jnp.bfloat16,
                                      rngs=nnx.Rngs(self.key),
                                      ed_sharding=PartitionSpec(),
                                      e_sharding=PartitionSpec(),
                                      activation_ffw_td=PartitionSpec(
                                          'data', None))
            self.moe = SparseMoE(
                hidden_size=self.D,
                intermediate_size_moe=self.moe_intermediate_size,
                num_local_experts=self.E,
                hidden_act="silu",
                num_experts_per_tok=self.K,
                router=router,
                dtype=jnp.bfloat16,
                rngs=nnx.Rngs(self.key),
                mesh=self.mesh,
                apply_expert_weight_before_computation=False,

                # Expert kernels are split along the 'model' axis.
                edf_sharding=PartitionSpec('model', None, None),
                efd_sharding=PartitionSpec('model', None, None),
                # Activations are partitioned over 'data', which has size 1 in
                # this mesh.
                activation_ffw_ted=PartitionSpec('data', None),
                activation_ffw_td=PartitionSpec('data', None))

    def test_token_replicated_expert_parallel_fwd(self):
        """
        Validates the MoE forward pass against a simple, dense equivalent.
        This specifically tests the is_batch_sharded_by_expert=False path.
        """
        # --- 1. Get the ACTUAL output from the distributed MoE layer ---
        # The call is made inside the mesh context the layer was built in.
        with self.mesh:
            actual_output = self.moe(self.x)

        # --- 2. Calculate the EXPECTED output using a sequential process ---
        # This serves as the "ground truth".

        # Get router decisions for the same input.
        router_weights, selected_experts = self.moe.router(self.x)

        # Copy the expert kernels to the host. The shape assertion below
        # verifies that the complete (E, D, F) kernel was obtained rather
        # than a per-device shard.
        gating_kernel_full = jax.device_get(self.moe.kernel_gating_EDF.value)
        up_proj_kernel_full = jax.device_get(self.moe.kernel_up_proj_EDF.value)
        down_proj_kernel_full = jax.device_get(
            self.moe.kernel_down_proj_EFD.value)

        self.assertEqual(gating_kernel_full.shape,
                         (self.E, self.D, self.moe_intermediate_size))

        # Flatten inputs for easier iteration
        num_tokens = self.B * self.S
        flat_x = self.x.reshape(num_tokens, self.D)
        flat_weights = router_weights.reshape(num_tokens, self.K)
        flat_experts = selected_experts.reshape(num_tokens, self.K)

        expected_output = jnp.zeros_like(flat_x)

        # Manually apply each expert to each token sequentially
        for i in range(num_tokens):  # For each token
            token_input = flat_x[i]
            combined_expert_output = jnp.zeros(self.D, dtype=jnp.bfloat16)

            for k in range(self.K):  # For each chosen expert for that token
                expert_idx = flat_experts[i, k]
                weight = flat_weights[i, k]

                # Pick this expert's kernels from the host copies.
                gating_kernel = gating_kernel_full[expert_idx]
                up_proj_kernel = up_proj_kernel_full[expert_idx]
                down_proj_kernel = down_proj_kernel_full[expert_idx]

                # Perform the expert computation (dense matmuls)
                gating_proj = jnp.dot(token_input, gating_kernel)
                up_proj = jnp.dot(token_input, up_proj_kernel)

                # 'silu' matches hidden_act passed to SparseMoE in setUp.
                fused = nnx.silu(gating_proj) * up_proj

                expert_output = jnp.dot(fused, down_proj_kernel)

                # The router weight is applied after the expert computation,
                # matching apply_expert_weight_before_computation=False.
                combined_expert_output += weight * expert_output

            expected_output = expected_output.at[i].set(combined_expert_output)

        # --- 3. Compare the results ---
        # expected_output already has shape (B * S, D), same as flat_x.
        self.assertTrue(
            jnp.allclose(actual_output, expected_output, atol=1e-2, rtol=1e-2),
            "The output of the distributed MoE does not match the dense equivalent.\n"
            f"Actual:\n{actual_output}\n"
            f"Expected:\n{expected_output}")
        print(
            "\n✅ Test Passed: Distributed MoE output matches the dense ground truth."
        )
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.