tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,235 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import unittest
16
+
17
+ import jax
18
+ import jax.numpy as jnp
19
+ import numpy as np
20
+ from flax import nnx
21
+ from jax.sharding import Mesh, PartitionSpec
22
+
23
+ from tpu_inference.layers.jax.moe.deepseek_v3_moe import (DeepSeekV3Router,
24
+ SparseMoE)
25
+
26
+
27
+ class TestDeepSeekV3Router(unittest.TestCase):
28
+
29
+ def setUp(self):
30
+ self.cpu_mesh = Mesh(jax.devices('cpu'), axis_names=('data', ))
31
+
32
+ def test_get_topk_indices_single_group(self):
33
+ """Test get_topk_indices with single expert group."""
34
+ with jax.set_mesh(self.cpu_mesh):
35
+ router = DeepSeekV3Router(random_init=True,
36
+ hidden_size=512,
37
+ num_experts=4,
38
+ num_experts_per_tok=2,
39
+ n_groups=1,
40
+ topk_groups=1,
41
+ norm_topk_prob=True,
42
+ routed_scaling_factor=1.0,
43
+ dtype=jnp.bfloat16,
44
+ rngs=nnx.Rngs(42))
45
+ router.bias_E = jnp.zeros((4, ))
46
+
47
+ scores = jnp.array([[0.1, 0.3, 0.2, 0.4]]) # shape: (1, 4)
48
+ indices = router.get_topk_indices(scores)
49
+
50
+ # Should return indices of top 2 experts
51
+ expected_indices = jnp.array([[3,
52
+ 1]]) # experts with scores 0.4, 0.3
53
+ self.assertTrue(jnp.array_equal(indices, expected_indices))
54
+
55
+ def test_get_topk_indices_2_groups(self):
56
+ """Test get_topk_indices with 2 expert groups."""
57
+ with jax.set_mesh(self.cpu_mesh):
58
+ router = DeepSeekV3Router(random_init=True,
59
+ hidden_size=512,
60
+ num_experts=4,
61
+ num_experts_per_tok=2,
62
+ n_groups=2,
63
+ topk_groups=1,
64
+ norm_topk_prob=True,
65
+ routed_scaling_factor=1.0,
66
+ dtype=jnp.bfloat16,
67
+ rngs=nnx.Rngs(42))
68
+ router.bias_E = jnp.zeros((4, ))
69
+
70
+ # 4 experts, 2 groups, 2 experts per group
71
+ scores = jnp.array([[[0.1, 0.3, 0.2, 0.4]]]) # shape: (1, 1, 4)
72
+ indices = router.get_topk_indices(scores)
73
+
74
+ # Only the best-scoring group (experts 2 and 3) is eligible, so the top 2 come from it
75
+ expected_indices = jnp.array([[[3, 2]]])
76
+ self.assertTrue(jnp.array_equal(indices, expected_indices))
77
+
78
+ def test_router_e2e(self):
79
+ with jax.set_mesh(self.cpu_mesh):
80
+ router = DeepSeekV3Router(random_init=True,
81
+ hidden_size=512,
82
+ num_experts=8,
83
+ num_experts_per_tok=2,
84
+ n_groups=2,
85
+ topk_groups=1,
86
+ norm_topk_prob=True,
87
+ routed_scaling_factor=1.0,
88
+ dtype=jnp.bfloat16,
89
+ rngs=nnx.Rngs(42))
90
+ x = jnp.ones((2, 512))
91
+ weights, indices = router(x)
92
+ self.assertEqual(weights.shape, (2, 2))
93
+ self.assertEqual(indices.shape, (2, 2))
94
+
95
+
96
+ class TestSparseMoE(unittest.TestCase):
97
+
98
+ def setUp(self):
99
+ """Set up a multi-device mesh and a sample MoE layer for testing."""
100
+ devices = jax.devices()
101
+ self.device_count = len(devices)
102
+ if self.device_count < 8:
103
+ self.skipTest("This test requires at least 8 simulated devices.")
104
+
105
+ # This mesh will have a 'model' axis for expert parallelism
106
+ mesh_shape = (self.device_count, 1)
107
+ device_mesh_array = np.array(devices).reshape(mesh_shape)
108
+
109
+ # Define the axis names
110
+ axis_names = ('model', 'data')
111
+
112
+ # Create the 2D mesh
113
+ self.mesh = Mesh(device_mesh_array, axis_names=axis_names)
114
+
115
+ # --- Model Configuration ---
116
+ self.B, self.S, self.D = 2, 4, 16 # Batch, Sequence, Hidden Dim
117
+ self.E, self.K = 16, 8 # Num Experts, Experts per Token
118
+ self.moe_intermediate_size = 32 # FFN Dim
119
+ self.num_expert_parallelism = 8 # Shard experts across 8 devices
120
+
121
+ self.key = jax.random.PRNGKey(42)
122
+ self.x = jax.random.normal(self.key, (self.B * self.S, self.D),
123
+ dtype=jnp.bfloat16)
124
+
125
+ # --- Instantiate MoE Layer ---
126
+ # We need to do this inside the mesh context
127
+ with self.mesh:
128
+ router = DeepSeekV3Router(hidden_size=self.D,
129
+ num_experts=self.E,
130
+ num_experts_per_tok=self.K,
131
+ n_groups=1,
132
+ topk_groups=1,
133
+ norm_topk_prob=False,
134
+ routed_scaling_factor=1.0,
135
+ dtype=jnp.bfloat16,
136
+ rngs=nnx.Rngs(self.key),
137
+ ed_sharding=PartitionSpec(),
138
+ e_sharding=PartitionSpec(),
139
+ activation_ffw_td=PartitionSpec(
140
+ 'data', None))
141
+ # Build the SparseMoE layer under test, reusing the router created above
142
+ self.moe = SparseMoE(
143
+ hidden_size=self.D,
144
+ intermediate_size_moe=self.moe_intermediate_size,
145
+ num_local_experts=self.E,
146
+ hidden_act="silu",
147
+ num_experts_per_tok=self.K,
148
+ router=router,
149
+ dtype=jnp.bfloat16,
150
+ rngs=nnx.Rngs(self.key),
151
+ mesh=self.mesh,
152
+ apply_expert_weight_before_computation=False,
153
+
154
+ # Expert kernels are sharded over 'model'; activations are partitioned over 'data'
155
+ edf_sharding=PartitionSpec('model', None, None),
156
+ efd_sharding=PartitionSpec('model', None, None),
157
+ activation_ffw_ted=PartitionSpec('data', None),
158
+ activation_ffw_td=PartitionSpec(
159
+ 'data', None) # 'data' axis has size 1, so activations are effectively replicated
160
+ )
161
+
162
+ def test_token_replicated_expert_parallel_fwd(self):
163
+ """
164
+ Validates the MoE forward pass against a simple, dense equivalent.
165
+ This specifically tests the is_batch_sharded_by_expert=False path.
166
+ """
167
+ # --- 1. Get the ACTUAL output from the complex distributed MoE layer ---
168
+ # The __call__ method will trigger the shard_map, which requires the mesh context.
169
+ with self.mesh:
170
+ actual_output = self.moe(self.x)
171
+
172
+ # --- 2. Calculate the EXPECTED output using a simple, sequential process ---
173
+ # This serves as the "ground truth".
174
+
175
+ # Get router decisions (router params are replicated, so this is fine)
176
+ router_weights, selected_experts = self.moe.router(self.x)
177
+
178
+ # --- Gather the full, unsharded weights from all devices ---
179
+ # .value yields the parameter as a global (possibly sharded) jax.Array.
180
+ # jax.device_get() copies the *full* array to the host as a NumPy array.
181
+ gating_kernel_full = jax.device_get(self.moe.kernel_gating_EDF.value)
182
+ up_proj_kernel_full = jax.device_get(self.moe.kernel_up_proj_EDF.value)
183
+ down_proj_kernel_full = jax.device_get(
184
+ self.moe.kernel_down_proj_EFD.value)
185
+
186
+ # Check that we really got the full weights
187
+ self.assertEqual(gating_kernel_full.shape,
188
+ (self.E, self.D, self.moe_intermediate_size))
189
+
190
+ # Flatten inputs for easier iteration
191
+ flat_x = self.x.reshape(self.B * self.S, self.D)
192
+ flat_weights = router_weights.reshape(self.B * self.S, self.K)
193
+ flat_experts = selected_experts.reshape(self.B * self.S, self.K)
194
+
195
+ expected_output = jnp.zeros_like(flat_x)
196
+
197
+ # Manually apply each expert to each token sequentially
198
+ for i in range(self.B * self.S): # For each token
199
+ token_input = flat_x[i]
200
+ combined_expert_output = jnp.zeros(self.D, dtype=jnp.bfloat16)
201
+
202
+ for k in range(self.K): # For each chosen expert for that token
203
+ expert_idx = flat_experts[i, k]
204
+ weight = flat_weights[i, k]
205
+
206
+ # Get kernels from the *full* gathered arrays
207
+ gating_kernel = gating_kernel_full[expert_idx]
208
+ up_proj_kernel = up_proj_kernel_full[expert_idx]
209
+ down_proj_kernel = down_proj_kernel_full[expert_idx]
210
+
211
+ # Perform the expert computation (dense matmuls)
212
+ gating_proj = jnp.dot(token_input, gating_kernel)
213
+ up_proj = jnp.dot(token_input, up_proj_kernel)
214
+
215
+ # Note: Assuming 'silu' activation as specified in MoE init
216
+ fused = nnx.silu(gating_proj) * up_proj
217
+
218
+ expert_output = jnp.dot(fused, down_proj_kernel)
219
+
220
+ # Apply router weight after computation (matches implementation)
221
+ combined_expert_output += weight * expert_output
222
+
223
+ expected_output = expected_output.at[i].set(combined_expert_output)
224
+
225
+ expected_output = expected_output.reshape(self.B * self.S, self.D)
226
+
227
+ # --- 3. Compare the results ---
228
+ self.assertTrue(
229
+ jnp.allclose(actual_output, expected_output, atol=1e-2, rtol=1e-2),
230
+ f"The output of the distributed MoE does not match the dense equivalent.\n"
231
+ f"Actual:\n{actual_output}\n"
232
+ f"Expected:\n{expected_output}")
233
+ print(
234
+ "\n✅ Test Passed: Distributed MoE output matches the dense ground truth."
235
+ )
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.