tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,208 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from types import SimpleNamespace
16
+ from unittest.mock import MagicMock, patch
17
+
18
+ import jax
19
+ import jax.numpy as jnp
20
+ import numpy as np
21
+ import pytest
22
+ from flax import nnx
23
+ from flax.typing import PRNGKey
24
+ from jax.sharding import Mesh
25
+
26
+ from tpu_inference.experimental.llama3_jax_stashed import (Llama3WeightLoader,
27
+ LlamaForCausalLM)
28
+
29
+
30
class MockParam:
    """A mock for a parameter used in the Llama model.

    Exposes ``value`` (an object carrying a ``shape``) and ``sharding``
    (an object carrying a ``spec``). Both are ordinary instance attributes,
    so tests can reassign ``value`` directly; no custom ``__setattr__`` is
    needed for that.
    """

    def __init__(self, shape=(32, 128)):
        self.value = SimpleNamespace(shape=shape)
        # The sharding spec is accessed during weight loading
        self.sharding = SimpleNamespace(spec=None)
44
+
45
+
46
class MockVllmConfig:
    """A mock VllmConfig sufficient for testing the Llama3 model.

    Populates ``model_config`` (a plain namespace), ``load_config`` and
    ``cache_config`` (``MagicMock`` objects) and ``additional_config`` (a
    nested dict carrying the random-weights flag and the tensor-parallel
    sharding strategy).
    """

    def __init__(self,
                 model_name: str,
                 random_weights: bool = False,
                 tensor_parallelism: int = 1):
        self.model_config = SimpleNamespace(
            model=model_name,
            dtype="bfloat16",
            hf_overrides={},
            override_generation_config={},
        )
        self.load_config = MagicMock()

        strategy = {"tensor_parallelism": tensor_parallelism}
        self.additional_config = {
            "random_weights": random_weights,
            "sharding": {
                "sharding_strategy": strategy
            },
        }

        # NOTE (jacobplatin): we could add a quantized KV cache test, but
        # we'll skip it for now.
        self.cache_config = MagicMock(cache_dtype="auto")
70
+
71
+
72
@pytest.fixture(scope="module")
def mesh():
    """Module-scoped mesh with axes ('data', 'model', 'expert').

    The sharding logic exercised by these tests expects all three axis
    names to exist. Every local device is placed on 'data', and the 'model'
    and 'expert' axes have size 1, so this also works on a single device.
    """
    # Check the same device list the mesh is built from.
    devices = jax.local_devices()
    if not devices:
        pytest.skip("No JAX devices available for mesh creation.")

    # Shape (num_devices, 1, 1): all devices on 'data', singleton other axes.
    device_mesh = np.array(devices).reshape((len(devices), 1, 1))

    with Mesh(device_mesh, axis_names=('data', 'model', 'expert')) as m:
        yield m
90
+
91
+
92
@pytest.fixture
def rng() -> PRNGKey:
    """Returns a fixed-seed (42) JAX PRNGKey shared by the tests."""
    seed = 42
    return jax.random.PRNGKey(seed)
96
+
97
+
98
@pytest.fixture
def mock_vllm_config_8b() -> MockVllmConfig:
    """Mock config whose model name refers to a Llama-3 8B checkpoint."""
    name = "meta-llama/Llama-3-8B"
    return MockVllmConfig(model_name=name)


@pytest.fixture
def mock_vllm_config_70b() -> MockVllmConfig:
    """Mock config whose model name refers to a Llama-3 70B checkpoint."""
    name = "meta-llama/Llama-3-70B-Instruct"
    return MockVllmConfig(model_name=name)


@pytest.fixture
def mock_vllm_config_unknown() -> MockVllmConfig:
    """Mock config whose model name matches no known Llama-3 variant."""
    name = "some-other-model"
    return MockVllmConfig(model_name=name)
111
+
112
+
113
+ # --- Test Cases ---
114
+
115
+
116
class TestLlamaForCausalLM:
    """Tests for the main LlamaForCausalLM model class."""

    def test_init_8b_variant(self, mock_vllm_config_8b, rng, mesh):
        """Tests correct parameter detection for the 8B model variant."""
        # Only configuration detection is checked, so build the model
        # abstractly (as the 70B test does) rather than allocating real
        # parameter arrays.
        model = nnx.eval_shape(
            lambda: LlamaForCausalLM(mock_vllm_config_8b, rng, mesh))
        assert model.hidden_size == 4096
        assert "8b" in model.vllm_config.model_config.model.lower()

    def test_init_70b_variant(self, mock_vllm_config_70b, rng, mesh):
        """Tests correct parameter detection for the 70B model variant."""
        model = nnx.eval_shape(
            lambda: LlamaForCausalLM(mock_vllm_config_70b, rng, mesh))
        assert model.hidden_size == 8192
        assert "70b" in model.vllm_config.model_config.model.lower()

    def test_init_unknown_variant_raises_error(self, mock_vllm_config_unknown,
                                               rng, mesh):
        """Tests that an unknown model variant raises a ValueError."""
        with pytest.raises(ValueError,
                           match="Could not determine Llama3 variant"):
            LlamaForCausalLM(mock_vllm_config_unknown, rng, mesh)

    def test_create_model_with_random_weights(self, mock_vllm_config_8b, rng,
                                              mesh):
        """
        Tests that random weight initialization creates concrete, non-zero-variance arrays.
        """
        with jax.set_mesh(mesh):
            model = LlamaForCausalLM(vllm_config=mock_vllm_config_8b,
                                     rng=rng,
                                     mesh=mesh,
                                     force_random_weights=True)

        embedding_weight = model.embedder.input_embedding_table_VD.value
        attention_q_kernel = model.layers[0].attn.kernel_q_proj_DNH.value
        final_norm_scale = model.final_norm.scale.value

        assert isinstance(embedding_weight, jax.Array)
        assert isinstance(attention_q_kernel, jax.Array)
        assert isinstance(final_norm_scale, jax.Array)

        assert jnp.std(embedding_weight) > 0
        assert jnp.std(attention_q_kernel) > 0

        assert jnp.all(final_norm_scale == 1.0)

    @patch("tpu_inference.experimental.llama3_jax_stashed.Llama3WeightLoader")
    def test_load_weights_called_correctly(self, mock_loader_cls, rng, mesh):
        """Tests that the weight loader is called correctly for checkpoint loading."""
        vllm_config = MockVllmConfig(model_name="llama3-8b",
                                     random_weights=False)
        model = LlamaForCausalLM(vllm_config, rng, mesh)

        # The loader class is patched; hand back a mock instance so the
        # delegated load_weights call can be inspected.
        mock_loader_instance = MagicMock()
        mock_loader_cls.return_value = mock_loader_instance
        model.load_weights(rng, cache_dir="/tmp/cache")
        mock_loader_cls.assert_called_once_with(vllm_config=vllm_config,
                                                hidden_size=4096,
                                                attn_heads=32,
                                                num_key_value_heads=8,
                                                attn_head_dim=128)
        mock_loader_instance.load_weights.assert_called_once_with(model)
179
+
180
+
181
class TestLlama3WeightLoader:
    """Tests for the Llama3WeightLoader class."""

    @pytest.fixture
    def weight_loader(self):
        """A small Llama3WeightLoader constructed directly from a mock config."""
        return Llama3WeightLoader(vllm_config=MockVllmConfig("test-model"),
                                  hidden_size=32,
                                  attn_heads=4,
                                  num_key_value_heads=2,
                                  attn_head_dim=8)

    def test_load_weights_transformation(self, weight_loader, rng, mesh):
        """Tests that load_weights delegates to load_hf_weights exactly once.

        ``load_hf_weights`` is patched out, so no checkpoint is read and the
        reshaping/transposition of real tensors is not exercised here.
        """
        vllm_config = MockVllmConfig("llama3-8b-small-test",
                                     random_weights=False)

        # The model is built from its own config; the loader fixture above
        # uses separate, smaller dimensions and is not derived from it.
        model = LlamaForCausalLM(vllm_config, rng, mesh)

        with patch(
                "tpu_inference.experimental.llama3_jax_stashed.load_hf_weights"
        ) as mock_load:
            weight_loader.load_weights(model)

        # Only the delegation is verified, not the arguments passed.
        mock_load.assert_called_once()
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,69 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import os
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ from absl.testing import absltest, parameterized
8
+ from jax._src import test_util as jtu
9
+
10
+ from tpu_inference import utils
11
+ from tpu_inference.kernels.collectives import all_gather_matmul
12
+
13
+ jax.config.parse_flags_with_absl()
14
+
15
+ P = jax.sharding.PartitionSpec
16
+
17
+ SpongeDir: str | None = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', None)
18
+
19
+
20
@jtu.with_config(jax_numpy_dtype_promotion='standard')
class AllGatherMatmulTest(jtu.JaxTestCase):
    """Compares `all_gather_matmul` against a plain `jnp.dot` reference."""

    @parameterized.product(
        grid_k=[1, 2, 3],
        grid_n=[1, 2, 3],
        rhs_transpose=[True, False],
    )
    def test_all_gather_matmul(self, grid_k, grid_n, rhs_transpose):
        num_devices = jax.device_count()
        # The test only runs on exactly 8 devices. Report the actual count,
        # since the skip can be triggered by too many devices as well as too
        # few.
        if num_devices != 8:
            self.skipTest(
                f'Test requires exactly 8 devices, found {num_devices}')

        axis_name = 'x'
        mesh = utils.make_optimized_mesh((num_devices, ), (axis_name, ))
        bk, bn = 1024, 1024
        # K spans grid_k blocks; N spans grid_n blocks per device.
        m, k, n = 1024, bk * grid_k, bn * grid_n * num_devices

        # x is sharded along M. y is sharded along N, which is its leading
        # axis when stored transposed as (N, K) and its trailing axis as (K, N).
        # These are loop-invariant, so build them once.
        y_shape = (n, k) if rhs_transpose else (k, n)
        y_spec = P(axis_name, None) if rhs_transpose else P(None, axis_name)
        x_sharding = jax.sharding.NamedSharding(mesh, P(axis_name, None))
        y_sharding = jax.sharding.NamedSharding(mesh, y_spec)

        # Run the test 10 times to expose race conditions as much as possible.
        for i in range(10):
            # Create input data
            prng_key = jax.random.key(1234 + i)
            k0, k1 = jax.random.split(prng_key, 2)
            x = jax.random.normal(k0, (m, k), dtype=jnp.bfloat16)
            y = jax.random.normal(k1, y_shape, dtype=jnp.bfloat16)
            sharded_x = jax.device_put(x, x_sharding)
            sharded_y = jax.device_put(y, y_sharding)

            # Run the all_gather_matmul function
            output = all_gather_matmul.all_gather_matmul(
                sharded_x,
                sharded_y,
                mesh,
                axis_name,
                bk=bk,
                bn=bn,
                rhs_transpose=rhs_transpose,
            )
            y_for_dot = sharded_y.T if rhs_transpose else sharded_y
            expected_output = jnp.dot(sharded_x, y_for_dot)
            self.assertAllClose(output, expected_output, atol=1e-2, rtol=1e-2)
66
+
67
+
68
if __name__ == "__main__":
    # Run directly with absltest, using JAX's test loader.
    loader = jtu.JaxTestLoader()
    absltest.main(testLoader=loader)
@@ -0,0 +1,388 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import jax
16
+ import jax.numpy as jnp
17
+ import numpy as np
18
+ from absl.testing import absltest, parameterized
19
+ from jax._src import test_util as jtu
20
+ from jax.sharding import Mesh
21
+
22
+ from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe, ref_moe
23
+
24
+ jax.config.parse_flags_with_absl()
25
+
26
+
27
def cdiv(a, b):
    """Returns ceil(a / b) for integers ``a`` and non-zero ``b``.

    The ``-(-a // b)`` form rounds toward positive infinity for every sign
    combination of ``a`` and ``b``. The ``(a + b - 1) // b`` form is only
    correct when ``b > 0``.
    """
    assert b != 0
    return -(-a // b)


def align_to(x, a):
    """Rounds ``x`` up to the nearest multiple of ``a`` (for ``a > 0``)."""
    return cdiv(x, a) * a
34
+
35
+
36
def gen_moe_inputs(
    dtype,
    top_k,
    num_experts,
    hidden_size,
    intermediate_size,
    num_tokens,
    *,
    seed=1234,
    has_bias=False,
):
    """Generates random inputs for the fused MoE kernel and its reference.

    Args:
        dtype: dtype the returned tensors are cast to.
        top_k: number of experts each token is routed to.
        num_experts: total number of experts.
        hidden_size: hidden dimension of the tokens.
        intermediate_size: expert intermediate dimension.
        num_tokens: number of input tokens.
        seed: PRNG seed; identical arguments give identical outputs.
        has_bias: whether to generate the ``b1``/``b2`` bias tensors.

    Returns:
        ``(a, w1, w2, b1, b2, gating_output)`` with shapes
        ``(num_tokens, hidden_size)``,
        ``(num_experts, 2, hidden_size, intermediate_size)``,
        ``(num_experts, intermediate_size, hidden_size)``,
        ``(num_experts, 2, intermediate_size)`` or ``None``,
        ``(num_experts, hidden_size)`` or ``None``, and
        ``(num_tokens, num_experts)``.

    Raises:
        ValueError: if ``top_k`` exceeds ``num_experts``.
    """
    if top_k > num_experts:
        raise ValueError(
            f"top_k ({top_k}) must not exceed num_experts ({num_experts}).")

    key = jax.random.key(seed)
    k0, k1, k2, k3, k4, k5, k6 = jax.random.split(key, 7)

    a = jax.random.normal(k0, (num_tokens, hidden_size),
                          dtype=jnp.float32).astype(dtype) / 10

    w1 = (jax.random.normal(
        k1,
        (num_experts, 2, hidden_size, intermediate_size),
        dtype=jnp.float32,
    ) / 10).astype(dtype)
    w2 = (jax.random.normal(k2, (num_experts, intermediate_size, hidden_size),
                            dtype=jnp.float32) / 10).astype(dtype)

    if has_bias:
        b1 = (jax.random.normal(k3, (num_experts, 2, intermediate_size),
                                dtype=jnp.float32) / 10).astype(dtype)
        b2 = (jax.random.normal(k4, (num_experts, hidden_size),
                                dtype=jnp.float32) / 10).astype(dtype)
    else:
        b1 = b2 = None

    gating_output = (
        jax.random.normal(k5, (num_tokens, num_experts), dtype=jnp.float32) +
        jnp.arange(num_tokens * num_experts, dtype=jnp.float32).reshape(
            num_tokens, num_experts) / 100)

    # Pick top_k *distinct* experts per token (a per-token random permutation
    # truncated to top_k), then boost their logits below so the intended
    # top-k set stands clearly apart. Independent randint draws could repeat
    # an expert, leaving fewer than top_k boosted. With maxval=num_experts - 1
    # (an exclusive bound), they could also never select the last expert.
    token_keys = jax.random.split(k6, num_tokens)
    top_k_indices = jax.vmap(
        lambda tk: jax.random.permutation(tk, num_experts)[:top_k])(token_keys)

    one_hot = (jnp.sum(
        jax.nn.one_hot(top_k_indices, num_experts, dtype=jnp.float32),
        axis=1,
    ) * 30)

    gating_output = (gating_output + one_hot).astype(dtype)

    return a, w1, w2, b1, b2, gating_output
88
+
89
+
90
def sub_channel_quantize(x, quant_dtype, wsz=256):
    """Quantizes x with sub-channel quantization on the 2nd minor.

    The second-to-last axis is split into windows of ``wsz`` rows. Each
    window gets its own per-column scale ``abs_max / dtype_max``, so the
    window's largest magnitude maps to the largest representable value of
    ``quant_dtype``.

    Args:
        x: array with at least 2 dims whose second-to-last dim is a
            multiple of ``wsz``.
        quant_dtype: target integer or floating dtype.
        wsz: window size along the second-to-last axis.

    Returns:
        ``(w, scale)``. ``w`` has ``x``'s shape and dtype ``quant_dtype``.
        ``scale`` is float32 with the second-to-last dim reduced to
        ``x.shape[-2] // wsz``.
    """
    is_float = jnp.issubdtype(quant_dtype, jnp.floating)
    dtype_info = jnp.finfo(quant_dtype) if is_float else jnp.iinfo(
        quant_dtype)
    dtype_max = float(dtype_info.max)
    assert x.ndim >= 2
    assert x.shape[-2] % wsz == 0
    w_lst, scale_lst = [], []
    for i in range(0, x.shape[-2], wsz):
        y = x[..., i:i + wsz, :]
        abs_max = jnp.abs(y).max(axis=-2, keepdims=True)
        scale = (abs_max / dtype_max).astype(jnp.float32)
        # An all-zero window has scale 0. Divide by 1 there instead so the
        # quantized values are 0 rather than 0/0 = NaN. The stored scale
        # stays 0, so the window still dequantizes to 0.
        safe_scale = jnp.where(scale == 0, 1.0, scale)
        q = y / safe_scale
        if not is_float:
            # float->int casts truncate toward zero; round to nearest and
            # keep the result inside the representable range.
            q = jnp.clip(jnp.round(q), dtype_info.min, dtype_info.max)
        w_lst.append(q.astype(quant_dtype))
        scale_lst.append(scale)
    return jnp.concat(w_lst, axis=-2), jnp.concat(scale_lst, axis=-2)
108
+
109
+
110
@jtu.with_config(jax_numpy_dtype_promotion="standard")
class MoEKernelTest(jtu.JaxTestCase):
    """Compares the fused EP MoE kernel against the ``ref_moe`` reference."""

    # Problem size shared by the small functional tests below.
    _SMALL_PROBLEM = dict(
        dtype=jnp.bfloat16,
        top_k=8,
        num_experts=128,
        hidden_size=1024,
        intermediate_size=1024,
        num_tokens=8 * 32,
    )

    def setUp(self):
        super().setUp()
        # The device ordering below reads per-device `coords`, and the kernel
        # under test is a TPU kernel; skip cleanly on other backends instead
        # of failing with an AttributeError.
        if not jtu.test_device_matches(["tpu"]):
            self.skipTest("Fused MoE kernel tests require TPU devices.")
        # Order devices by coords[0], then by coords[1] ascending on even
        # rows and descending on odd rows (a serpentine ordering).
        self.mesh_devices = sorted(
            jax.devices(),
            key=lambda x: (
                x.coords[0],
                (-1 if x.coords[0] % 2 else 1) * x.coords[1],
            ),
        )
        self.mesh = Mesh(np.array(self.mesh_devices).reshape(1, -1),
                         axis_names=("data", "model"))

    def _test_moe(
        self,
        dtype,
        top_k,
        num_experts,
        hidden_size,
        intermediate_size,
        num_tokens,
        seed,
        renormalize_topk_logits,
        bt,
        bf,
        bd1,
        bd2,
        btc,
        bfc,
        bd1c,
        bd2c,
        act_fn="silu",
        w_dtype=None,
        subc_quant_wsz=None,
        has_bias=False,
        atol=2e-1,
        rtol=2e-1,
    ):
        """Runs fused_ep_moe and ref_moe on the same inputs and compares them.

        When ``w_dtype`` is given, both weight tensors are sub-channel
        quantized (window ``subc_quant_wsz``, default 256), and the same
        quantized weights and scales are fed to both implementations.
        """
        a, w1, w2, b1, b2, gating_output = gen_moe_inputs(
            dtype,
            top_k,
            num_experts,
            hidden_size,
            intermediate_size,
            num_tokens,
            seed=seed,
            has_bias=has_bias,
        )
        w1_scale = None
        w2_scale = None
        if w_dtype is not None:
            if subc_quant_wsz is None:
                subc_quant_wsz = 256
            w1, w1_scale = sub_channel_quantize(w1, w_dtype, subc_quant_wsz)
            w2, w2_scale = sub_channel_quantize(w2, w_dtype, subc_quant_wsz)

        actual = fused_ep_moe(
            mesh=self.mesh,
            tokens=a,
            w1=w1,
            w2=w2,
            gating_output=gating_output,
            top_k=top_k,
            renormalize_topk_logits=renormalize_topk_logits,
            act_fn=act_fn,
            subc_quant_wsz=subc_quant_wsz,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            b1=b1,
            b2=b2,
            bt=bt,
            bf=bf,
            bd1=bd1,
            bd2=bd2,
            btc=btc,
            bfc=bfc,
            bd1c=bd1c,
            bd2c=bd2c,
        )
        expected = ref_moe(
            a,
            w1,
            w2,
            gating_output,
            top_k,
            b1=b1,
            b2=b2,
            renormalize_topk_logits=renormalize_topk_logits,
            activation=act_fn,
            subc_quant_wsz=subc_quant_wsz,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
        )
        self.assertAllClose(actual, expected, atol=atol, rtol=rtol)

    @parameterized.product(renormalize_topk_logits=[True, False])
    def test_basic(self, renormalize_topk_logits):
        self._test_moe(
            **self._SMALL_PROBLEM,
            seed=1234,
            renormalize_topk_logits=renormalize_topk_logits,
            bt=32,
            bf=1024,
            bd1=1024,
            bd2=1024,
            btc=32,
            bfc=256,
            bd1c=256,
            bd2c=256,
        )

    @parameterized.product(act_fn=["silu", "gelu", "swigluoai"])
    def test_activation(self, act_fn):
        self._test_moe(
            **self._SMALL_PROBLEM,
            seed=1234,
            renormalize_topk_logits=True,
            act_fn=act_fn,
            bt=32,
            bf=512,
            bd1=512,
            bd2=512,
            btc=32,
            bfc=256,
            bd1c=256,
            bd2c=256,
        )

    def test_benchmark_qwen_235(self):
        self._test_moe(
            dtype=jnp.bfloat16,
            top_k=8,
            num_experts=128,
            hidden_size=4096,
            intermediate_size=1536,
            num_tokens=8 * 64,
            seed=54321,
            renormalize_topk_logits=True,
            bt=64,
            bf=768,
            bd1=2048,
            bd2=2048,
            btc=64,
            bfc=768,
            bd1c=2048,
            bd2c=2048,
            act_fn="silu",
            atol=5e-2,
            rtol=5e-2,
        )

    def test_benchmark_qwen_30b_a3b(self):
        self._test_moe(
            dtype=jnp.bfloat16,
            top_k=8,
            num_experts=128,
            hidden_size=2048,
            intermediate_size=768,
            num_tokens=512,
            seed=54321,
            renormalize_topk_logits=True,
            bt=16,
            bf=384,
            bd1=512,
            bd2=512,
            btc=16,
            bfc=384,
            bd1c=256,
            bd2c=256,
            act_fn="silu",
            atol=5e-2,
            rtol=5e-2,
        )

    @parameterized.product(
        w_dtype=[jnp.int8, jnp.float8_e5m2, jnp.float4_e2m1fn])
    def test_sub_channel_quantization(self, w_dtype):
        if w_dtype in (
                jnp.float8_e5m2,
                jnp.float4_e2m1fn,
        ) and not jtu.is_device_tpu_at_least(version=7):
            self.skipTest("Expect TPUv7+")
        self._test_moe(
            **self._SMALL_PROBLEM,
            seed=1234,
            renormalize_topk_logits=False,
            w_dtype=w_dtype,
            subc_quant_wsz=256,
            bt=32,
            bf=1024,
            bd1=1024,
            bd2=1024,
            btc=32,
            bfc=256,
            bd1c=256,
            bd2c=256,
        )

    def test_bias(self):
        self._test_moe(
            **self._SMALL_PROBLEM,
            seed=1234,
            renormalize_topk_logits=False,
            has_bias=True,
            bt=32,
            bf=512,
            bd1=512,
            bd2=512,
            btc=32,
            bfc=256,
            bd1c=256,
            bd2c=256,
        )
385
+
386
+
387
if __name__ == "__main__":
    # Run directly with absltest, using JAX's test loader.
    loader = jtu.JaxTestLoader()
    absltest.main(testLoader=loader)