tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic; see the package's registry page for more details.

Files changed (251)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +22 -1
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +31 -9
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +77 -36
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -71
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +158 -63
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +53 -30
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +54 -2
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +105 -57
  232. tpu_inference/runner/utils.py +2 -2
  233. tpu_inference/spec_decode/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/__init__.py +13 -0
  235. tpu_inference/spec_decode/jax/eagle3.py +65 -19
  236. tpu_inference/tpu_info.py +14 -0
  237. tpu_inference/utils.py +72 -44
  238. tpu_inference/worker/__init__.py +13 -0
  239. tpu_inference/worker/tpu_worker.py +65 -52
  240. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  241. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  242. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  244. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  245. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  246. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  247. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  248. tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
  249. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  250. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  251. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,46 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # This file contains end-to-end tests for structured decoding.
16
+ #
17
+ # Structured decoding allows constraining the model's output to follow a
18
+ # specific format, such as choosing from a predefined set of options or
19
+ # following a JSON schema. This is useful for classification tasks,
20
+ # structured data extraction, and ensuring outputs conform to expected formats.
21
+
22
+ # The tests in this file verify that:
23
+ # 1. Choice-based structured decoding correctly constrains output to valid options
24
+ # 2. The generated text is exactly one of the allowed choices (no extra tokens)
25
+
26
+ from __future__ import annotations
27
+
28
+ from vllm import LLM, SamplingParams
29
+ from vllm.sampling_params import StructuredOutputsParams
30
+
31
+
32
def test_structured_decoding():
    """Choice-constrained decoding must emit exactly one of the allowed labels."""
    allowed_labels = ['Positive', 'Negative']

    engine = LLM(model='meta-llama/Llama-3.2-1B-Instruct',
                 max_model_len=1024,
                 max_num_seqs=1,
                 enable_prefix_caching=False)

    params = SamplingParams(
        structured_outputs=StructuredOutputsParams(choice=allowed_labels))
    results = engine.generate(
        prompts="Classify this sentiment: tpu-inference is wonderful!",
        sampling_params=params,
    )

    # `in` on a list compares by equality, so this requires an exact match.
    generated_text = results[0].outputs[0].text
    assert generated_text in allowed_labels
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,199 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import unittest
16
+ from unittest.mock import MagicMock, patch
17
+
18
+
19
+ # Mock VllmConfig and its nested configs to avoid dependencies on the actual
20
+ # classes, which can be complex to instantiate for testing.
21
class MockVllmConfig:
    """Lightweight stand-in for ``VllmConfig`` used by the executor tests.

    Every nested config is a ``MagicMock``, so arbitrary attribute access
    succeeds. Only the parallel and sharding fields the tests depend on are
    given concrete values.
    """

    def __init__(self):
        parallel = MagicMock()
        parallel.world_size = 4
        parallel.tensor_parallel_size = 2
        parallel.pipeline_parallel_size = 1
        parallel.ray_workers_use_nsight = False
        parallel.placement_group = None
        parallel.max_parallel_loading_workers = None
        self.parallel_config = parallel

        sharding = MagicMock()
        sharding.total_devices = 2
        self.sharding_config = sharding

        # The remaining sub-configs are plain mocks with no preset values.
        for attr_name in ("model_config", "cache_config", "lora_config",
                          "load_config", "scheduler_config",
                          "speculative_config", "prompt_adapter_config",
                          "observability_config", "device_config",
                          "ec_transfer_config"):
            setattr(self, attr_name, MagicMock())
+
46
+
47
@patch(
    "vllm.v1.executor.ray_distributed_executor.RayDistributedExecutor.__init__",
    lambda x, y: None)
@patch("tpu_inference.executors.ray_distributed_executor.envs")
@patch("tpu_inference.executors.ray_distributed_executor.ray")
@patch("tpu_inference.executors.ray_distributed_executor.current_platform")
@patch("tpu_inference.executors.ray_distributed_executor.get_ip",
       return_value="127.0.0.1")
@patch("tpu_inference.executors.ray_distributed_executor.get_open_port",
       return_value=12345)
@patch(
    "tpu_inference.executors.ray_distributed_executor.available_resources_per_node"
)
@patch("tpu_inference.executors.ray_distributed_executor._wait_until_pg_ready")
class TestTpuRayDistributedExecutor(unittest.TestCase):
    """Unit tests for tpu_inference's ``RayDistributedExecutor``.

    The class-level ``patch`` decorators wrap every ``test_*`` method. The
    vLLM parent ``__init__`` is replaced by a no-op passed as ``new``, so it
    injects no argument. The other patches inject mocks in bottom-up order:
    ``_wait_until_pg_ready``, ``available_resources_per_node``,
    ``get_open_port``, ``get_ip``, ``current_platform``, ``ray``, ``envs``.
    """

    def setUp(self):
        # Import lazily, at test time rather than at module import time.
        # Note that class-level patches are active only inside the test_*
        # methods, not during setUp.
        from tpu_inference.executors.ray_distributed_executor import \
            RayDistributedExecutor
        self.RayDistributedExecutor = RayDistributedExecutor

        self.vllm_config = MockVllmConfig()
        # Reset placement group for each test as it might be modified.
        self.vllm_config.parallel_config.placement_group = None
        self.vllm_config.kv_transfer_config = None

    def test_init_executor_basic_flow(self, mock_wait_until_pg_ready,
                                      mock_avail_resources, mock_get_port,
                                      mock_get_ip, mock_platform, mock_ray,
                                      mock_envs):
        """_init_executor calls ray.init, sets a placement group, and ends up
        with 4 workers."""
        # --- Setup mocks ---
        mock_envs.VLLM_USE_RAY_COMPILED_DAG = True
        mock_envs.VLLM_USE_RAY_SPMD_WORKER = True
        mock_envs.VLLM_RAY_BUNDLE_INDICES = ""

        mock_platform.ray_device_key = "TPU"
        mock_platform.device_name = "tpu"
        mock_platform.device_control_env_var = "TPU_VISIBLE_CHIPS"
        mock_platform.additional_env_vars = []

        mock_ray.is_initialized.return_value = False
        mock_ray.nodes.return_value = [{"Resources": {"TPU": 4}}]
        mock_ray.get_runtime_context.return_value.get_node_id.return_value = "node_1"
        mock_avail_resources.return_value = {"node_1": {"TPU": 4}}

        mock_wait_until_pg_ready.return_value = None

        mock_placement_group = MagicMock()
        mock_placement_group.bundle_specs = [{"TPU": 1}] * 4
        mock_ray.util.placement_group.return_value = mock_placement_group

        mock_worker = MagicMock()
        mock_worker.get_node_and_gpu_ids.remote.return_value = [("node_1",
                                                                 [0, 1, 2, 3])]
        mock_ray.remote.return_value.remote.return_value = mock_worker

        # Simulate remote calls on the worker
        mock_ray.get.side_effect = [
            ["127.0.0.1"] * 4,  # worker_ips
            *[("node_1", [i]) for i in range(4)]  # worker_node_and_tpu_ids
        ]

        executor = self.RayDistributedExecutor(self.vllm_config)
        # Members of the parent class (its __init__ is patched to a no-op)
        executor.uses_ray = True
        executor.vllm_config = self.vllm_config
        executor.parallel_config = self.vllm_config.parallel_config
        executor.collective_rpc = MagicMock()
        executor.collective_rpc.return_value = None

        # --- Initialization ---
        executor._init_executor()

        # --- Assertions ---
        mock_ray.init.assert_called_once()
        self.assertIsNotNone(executor.parallel_config.placement_group)
        self.assertEqual(len(executor.workers), 4)

    def test_initialize_ray_cluster_no_tpu_on_driver_raises_error(
            self, mock_wait_until_pg_ready, mock_avail_resources,
            mock_get_port, mock_get_ip, mock_platform, mock_ray, mock_envs):
        """_initialize_ray_cluster raises ValueError when the driver node
        reports no TPU resources."""
        # --- Setup Mocks ---
        mock_platform.ray_device_key = "TPU"
        mock_platform.device_name = "tpu"

        mock_ray.is_initialized.return_value = False
        mock_ray.nodes.return_value = [{"Resources": {"TPU": 4}}]
        mock_ray.get_runtime_context.return_value.get_node_id.return_value = "driver_node"
        # Simulate no TPUs on the driver node
        mock_avail_resources.return_value = {
            "driver_node": {
                "CPU": 8
            },
            "worker_node": {
                "TPU": 4
            }
        }

        executor = self.RayDistributedExecutor(self.vllm_config)
        executor.vllm_config = self.vllm_config
        executor.parallel_config = self.vllm_config.parallel_config

        # --- Test and Assert ---
        with self.assertRaisesRegex(ValueError,
                                    "Current node has no TPU available"):
            executor._initialize_ray_cluster()

    def test_init_workers_ray_sorts_correctly(self, mock_wait_until_pg_ready,
                                              mock_avail_resources,
                                              mock_get_port, mock_get_ip,
                                              mock_platform, mock_ray,
                                              mock_envs):
        """_init_workers_ray puts the worker on the driver IP first and
        orders the remaining workers by IP."""
        # --- Setup Mocks ---
        mock_envs.VLLM_RAY_BUNDLE_INDICES = ""
        mock_platform.ray_device_key = "TPU"
        mock_get_ip.return_value = "10.0.0.1"  # Driver IP

        mock_pg = MagicMock()
        mock_pg.bundle_specs = [{"TPU": 1}] * 4

        mock_workers = [MagicMock() for _ in range(4)]
        mock_ray.remote.return_value.return_value.remote.side_effect = mock_workers

        # Simulate IPs for workers created with ranks 0, 1, 2, 3
        worker_ips = ["10.0.0.2", "10.0.0.3", "10.0.0.1", "10.0.0.4"]
        mock_ray.get.side_effect = [
            worker_ips,  # worker_ips
            *[('node_1', ['0', '1', '2', '3']),
              ('node_2', ['4', '5', '6', '7']),
              ('node_3', ['8', '9', '10', '11']),
              ('node_4', ['12', '13', '14', '15'])]  # worker_node_and_tpu_ids
        ]

        executor = self.RayDistributedExecutor(self.vllm_config)
        executor.use_ray_spmd_worker = True
        executor.parallel_config = self.vllm_config.parallel_config
        executor.vllm_config = self.vllm_config
        executor.parallel_config.ray_workers_use_nsight = False
        executor.collective_rpc = MagicMock()
        executor.collective_rpc.return_value = None

        # --- Call method under test ---
        executor._init_workers_ray(mock_pg)

        # --- Assertions ---
        # Expected order: the worker on the driver IP first, then by IP.
        # Original workers: 0 (10.0.0.2), 1 (10.0.0.3), 2 (10.0.0.1), 3 (10.0.0.4)
        # Sorted workers:   2 (driver), 0, 1, 3
        self.assertEqual(executor.workers, [
            mock_workers[2], mock_workers[0], mock_workers[1], mock_workers[3]
        ])
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,208 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from types import SimpleNamespace
16
+ from unittest.mock import MagicMock, patch
17
+
18
+ import jax
19
+ import jax.numpy as jnp
20
+ import numpy as np
21
+ import pytest
22
+ from flax import nnx
23
+ from flax.typing import PRNGKey
24
+ from jax.sharding import Mesh
25
+
26
+ from tpu_inference.experimental.llama3_jax_stashed import (Llama3WeightLoader,
27
+ LlamaForCausalLM)
28
+
29
+
30
class MockParam:
    """A mock for a parameter used in the Llama model.

    ``value`` exposes only a ``shape``, and ``sharding.spec`` is ``None``.
    Both are ordinary instance attributes, so ``value`` can be reassigned
    freely. No ``__setattr__`` override is needed: the previous one stored
    ``value`` in ``__dict__`` exactly as ``object.__setattr__`` already does.
    """

    def __init__(self, shape=(32, 128)):
        self.value = SimpleNamespace(shape=shape)
        # The sharding spec is accessed during weight loading
        self.sharding = SimpleNamespace(spec=None)
44
+
45
+
46
class MockVllmConfig:
    """Minimal ``VllmConfig`` substitute for exercising the Llama3 model.

    Args:
        model_name: Exposed as ``model_config.model``.
        random_weights: Stored as ``additional_config["random_weights"]``.
        tensor_parallelism: Stored under
            ``additional_config["sharding"]["sharding_strategy"]``.
    """

    def __init__(self,
                 model_name: str,
                 random_weights: bool = False,
                 tensor_parallelism: int = 1):
        self.model_config = SimpleNamespace(
            model=model_name,
            dtype="bfloat16",
            hf_overrides={},
            override_generation_config={},
        )
        self.load_config = MagicMock()

        strategy = {"tensor_parallelism": tensor_parallelism}
        self.additional_config = dict(
            random_weights=random_weights,
            sharding={"sharding_strategy": strategy},
        )

        # NOTE (jacobplatin): we could add a quantized KV cache test, but
        # we'll skip it for now.
        self.cache_config = MagicMock(cache_dtype="auto")
70
+
71
+
72
@pytest.fixture(scope="module")
def mesh():
    """Yield a 3-axis ``('data', 'model', 'expert')`` mesh for the module.

    The sharding logic expects all three axis names to exist, even when only
    one device is present. Every local device goes on the ``data`` axis;
    ``model`` and ``expert`` each have size 1.
    """
    if not jax.devices():
        pytest.skip("No JAX devices available for mesh creation.")

    local = np.array(jax.local_devices())
    grid = local.reshape((len(local), 1, 1))

    with Mesh(grid, axis_names=('data', 'model', 'expert')) as active_mesh:
        yield active_mesh
90
+
91
+
92
@pytest.fixture
def rng() -> PRNGKey:
    """Return a JAX PRNGKey with a fixed seed (42) for reproducible tests."""
    seed = 42
    return jax.random.PRNGKey(seed)
96
+
97
+
98
@pytest.fixture
def mock_vllm_config_8b() -> MockVllmConfig:
    """Mock config whose model name refers to the Llama-3 8B variant."""
    config = MockVllmConfig(model_name="meta-llama/Llama-3-8B")
    return config
101
+
102
+
103
@pytest.fixture
def mock_vllm_config_70b() -> MockVllmConfig:
    """Mock config whose model name refers to the Llama-3 70B Instruct variant."""
    config = MockVllmConfig(model_name="meta-llama/Llama-3-70B-Instruct")
    return config
106
+
107
+
108
@pytest.fixture
def mock_vllm_config_unknown() -> MockVllmConfig:
    """Mock config whose model name matches no known Llama-3 variant."""
    config = MockVllmConfig(model_name="some-other-model")
    return config
111
+
112
+
113
+ # --- Test Cases ---
114
+
115
+
116
class TestLlamaForCausalLM:
    """Tests for the main LlamaForCausalLM model class."""

    def test_init_8b_variant(self, mock_vllm_config_8b, rng, mesh):
        """An 8B model name yields the 8B hidden size (4096)."""
        llama = LlamaForCausalLM(mock_vllm_config_8b, rng, mesh)
        assert llama.hidden_size == 4096
        model_name = llama.vllm_config.model_config.model.lower()
        assert "8b" in model_name

    def test_init_70b_variant(self, mock_vllm_config_70b, rng, mesh):
        """A 70B model name yields the 70B hidden size (8192).

        Built under ``nnx.eval_shape`` so the 70B parameters are not
        materialized.
        """

        def build():
            return LlamaForCausalLM(mock_vllm_config_70b, rng, mesh)

        llama = nnx.eval_shape(build)
        assert llama.hidden_size == 8192
        model_name = llama.vllm_config.model_config.model.lower()
        assert "70b" in model_name

    def test_init_unknown_variant_raises_error(self, mock_vllm_config_unknown,
                                               rng, mesh):
        """An unrecognized model name raises a ValueError."""
        expected_msg = "Could not determine Llama3 variant"
        with pytest.raises(ValueError, match=expected_msg):
            LlamaForCausalLM(mock_vllm_config_unknown, rng, mesh)

    def test_create_model_with_random_weights(self, mock_vllm_config_8b, rng,
                                              mesh):
        """Random initialization yields concrete arrays.

        Embedding and q-projection weights have non-zero spread, and the
        final norm scale is all ones.
        """
        with jax.set_mesh(mesh):
            llama = LlamaForCausalLM(vllm_config=mock_vllm_config_8b,
                                     rng=rng,
                                     mesh=mesh,
                                     force_random_weights=True)

            embedding = llama.embedder.input_embedding_table_VD.value
            q_kernel = llama.layers[0].attn.kernel_q_proj_DNH.value
            norm_scale = llama.final_norm.scale.value

            for array in (embedding, q_kernel, norm_scale):
                assert isinstance(array, jax.Array)

            for array in (embedding, q_kernel):
                assert jnp.std(array) > 0

            assert jnp.all(norm_scale == 1.0)

    @patch("tpu_inference.experimental.llama3_jax_stashed.Llama3WeightLoader")
    def test_load_weights_called_correctly(self, mock_loader_cls, rng, mesh):
        """load_weights builds a Llama3WeightLoader with the 8B dimensions
        and passes it the model."""
        vllm_config = MockVllmConfig(model_name="llama3-8b",
                                     random_weights=False)
        llama = LlamaForCausalLM(vllm_config, rng, mesh)

        loader = MagicMock()
        mock_loader_cls.return_value = loader
        llama.load_weights(rng, cache_dir="/tmp/cache")

        mock_loader_cls.assert_called_once_with(vllm_config=vllm_config,
                                                hidden_size=4096,
                                                attn_heads=32,
                                                num_key_value_heads=8,
                                                attn_head_dim=128)
        loader.load_weights.assert_called_once_with(llama)
179
+
180
+
181
class TestLlama3WeightLoader:
    """Tests for the Llama3WeightLoader class."""

    @pytest.fixture
    def weight_loader(self):
        """A Llama3WeightLoader constructed with small, test-sized dimensions."""
        return Llama3WeightLoader(vllm_config=MockVllmConfig("test-model"),
                                  hidden_size=32,
                                  attn_heads=4,
                                  num_key_value_heads=2,
                                  attn_head_dim=8)

    def test_load_weights_transformation(self, weight_loader, rng, mesh):
        """load_weights delegates to ``load_hf_weights`` exactly once.

        ``load_hf_weights`` is patched out, so this only checks that the call
        is made. It does not inspect any reshaped or transposed weights.
        """
        vllm_config = MockVllmConfig("llama3-8b-small-test",
                                     random_weights=False)

        # A real model instance is handed to the loader under test.
        model = LlamaForCausalLM(vllm_config, rng, mesh)

        with patch(
                "tpu_inference.experimental.llama3_jax_stashed.load_hf_weights"
        ) as mock_load:
            weight_loader.load_weights(model)

            # The patched HF weight loader must be invoked exactly once.
            mock_load.assert_called_once()
tests/kernels/__init__.py CHANGED
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,69 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import os
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ from absl.testing import absltest, parameterized
8
+ from jax._src import test_util as jtu
9
+
10
+ from tpu_inference import utils
11
+ from tpu_inference.kernels.collectives import all_gather_matmul
12
+
13
# Let JAX pick up its configuration flags through absl flag parsing.
# Presumably required because this module runs under absltest (see the
# __main__ guard); keep this call ahead of the test definitions.
jax.config.parse_flags_with_absl()

# Short alias used to build sharding specs in the tests below.
P = jax.sharding.PartitionSpec

# Output directory supplied by the test runner via
# TEST_UNDECLARED_OUTPUTS_DIR, or None when the variable is unset.
# NOTE(review): this name is not referenced anywhere else in this file.
SpongeDir: str | None = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', None)
18
+
19
+
20
@jtu.with_config(jax_numpy_dtype_promotion='standard')
class AllGatherMatmulTest(jtu.JaxTestCase):
    """Compares ``all_gather_matmul`` against a plain ``jnp.dot`` reference."""

    @parameterized.product(
        grid_k=[1, 2, 3],
        grid_n=[1, 2, 3],
        rhs_transpose=[True, False],
    )
    def test_all_gather_matmul(self, grid_k, grid_n, rhs_transpose):
        """Checks all_gather_matmul for several K/N sizes and both RHS layouts.

        The LHS ``x`` of shape (M, K) is sharded along M across the mesh
        axis. The RHS is sharded along N: it has shape (N, K) sharded on its
        rows when ``rhs_transpose`` is set, and (K, N) sharded on its columns
        otherwise.
        """
        num_devices = jax.device_count()
        # The test only runs on exactly 8 devices. More than 8 also skips, so
        # the reason states the real requirement rather than "not enough".
        if num_devices != 8:
            self.skipTest(
                f'Test requires exactly 8 devices, found {num_devices}')

        axis_name = 'x'
        mesh = utils.make_optimized_mesh((num_devices, ), (axis_name, ))
        bk, bn = 1024, 1024
        m, k, n = 1024, bk * grid_k, bn * grid_n * num_devices

        # Shapes and shardings do not depend on the random data, so they are
        # built once instead of on every iteration.
        y_shape = (n, k) if rhs_transpose else (k, n)
        x_sharding = jax.sharding.NamedSharding(mesh, P(axis_name, None))
        y_sharding = jax.sharding.NamedSharding(
            mesh,
            P(axis_name, None) if rhs_transpose else P(None, axis_name))

        # Run the test 10 times to expose race conditions as much as possible.
        for i in range(10):
            # Fresh, deterministic inputs for each iteration.
            prng_key = jax.random.key(1234 + i)
            k0, k1 = jax.random.split(prng_key, 2)
            x = jax.random.normal(k0, (m, k), dtype=jnp.bfloat16)
            y = jax.random.normal(k1, y_shape, dtype=jnp.bfloat16)
            sharded_x = jax.device_put(x, x_sharding)
            sharded_y = jax.device_put(y, y_sharding)

            output = all_gather_matmul.all_gather_matmul(
                sharded_x,
                sharded_y,
                mesh,
                axis_name,
                bk=bk,
                bn=bn,
                rhs_transpose=rhs_transpose,
            )

            # Reference: an ordinary matmul on the same inputs.
            y_for_dot = sharded_y.T if rhs_transpose else sharded_y
            expected_output = jnp.dot(sharded_x, y_for_dot)
            # Including the iteration index identifies which run exposed a
            # (possibly intermittent) mismatch.
            self.assertAllClose(output,
                                expected_output,
                                atol=1e-2,
                                rtol=1e-2,
                                err_msg=f'mismatch on iteration {i}')
66
+
67
+
68
if __name__ == "__main__":
    # Run the tests with absltest, discovering them via JAX's test loader.
    _jax_loader = jtu.JaxTestLoader()
    absltest.main(testLoader=_jax_loader)
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  import jax
2
16
  import jax.numpy as jnp
3
17
  import numpy as np