tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (257) hide show
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,200 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Unit tests for TPUModelRunner mesh initialization."""
15
+ import os
16
+ from unittest.mock import Mock, patch
17
+
18
+ import pytest
19
+
20
+ from tpu_inference.runner.tpu_runner import TPUModelRunner
21
+
22
+
23
class TestTPUModelRunnerMeshInit:
    """Tests for TPUModelRunner._init_mesh and the mesh builders it calls.

    The runner under test is a ``Mock(spec=TPUModelRunner)`` whose mesh-related
    methods are forwarded to the real class implementations, so only the mesh
    construction logic runs; the JAX mesh factories are patched in each test.
    """

    @pytest.fixture
    def mock_vllm_config(self):
        """Return a mock config carrying a fixed sharding configuration."""
        sharding = Mock()
        sharding.model_dp_size = 4
        sharding.attn_dp_size = 2
        sharding.expert_size = 1
        sharding.tp_size = 8
        # None means no explicit device ordering is requested.
        sharding.device_indexes = None
        sharding.total_dp_size = 4

        config = Mock()
        config.sharding_config = sharding
        return config

    @pytest.fixture
    def mock_devices(self):
        """Return 64 mock devices with ids 0..63."""
        return [Mock(id=device_id) for device_id in range(64)]

    @pytest.fixture
    def runner_instance(self, mock_vllm_config, mock_devices):
        """Return a spec'd mock runner wired to the real mesh methods."""
        runner = Mock(spec=TPUModelRunner)
        runner.vllm_config = mock_vllm_config
        runner.devices = mock_devices
        runner.mesh = None

        def call_real(method_name, *args):
            # Resolve the method on the real class at call time and pass the
            # mock in as ``self``.
            return getattr(TPUModelRunner, method_name)(runner, *args)

        runner._init_mesh = lambda: call_real("_init_mesh")
        runner._create_new_model_mesh = lambda: call_real(
            "_create_new_model_mesh")
        runner._create_2d_mesh = lambda: call_real("_create_2d_mesh")
        runner._create_single_slice_mesh = lambda: call_real(
            "_create_single_slice_mesh")
        runner._create_multi_slice_mesh = lambda ns: call_real(
            "_create_multi_slice_mesh", ns)

        return runner

    def test_init_mesh_2d_model_without_device_order(self, runner_instance,
                                                     mock_vllm_config):
        """2D mesh, device_indexes unset: make_optimized_mesh builds the mesh."""
        with patch.dict(os.environ, {'NEW_MODEL_DESIGN': ''}), \
             patch('tpu_inference.runner.tpu_runner.make_optimized_mesh') as make_mesh, \
             patch('tpu_inference.runner.tpu_runner.logger'):

            fake_mesh = Mock()
            make_mesh.return_value = fake_mesh

            runner_instance._init_mesh()

            make_mesh.assert_called_once()
            positional, keywords = make_mesh.call_args

            # Mesh shape is (model_dp_size, tp_size).
            assert positional[0] == (4, 8)
            # Axis names.
            assert positional[1] == ("data", "model")
            # Devices are forwarded by keyword.
            assert keywords['devices'] == runner_instance.devices

            assert runner_instance.mesh == fake_mesh

    def test_init_mesh_2d_model_with_device_order(self, runner_instance,
                                                  mock_vllm_config):
        """2D mesh, device_indexes set: jax.make_mesh builds the mesh."""
        mock_vllm_config.sharding_config.device_indexes = [0, 1, 2, 3]

        with patch.dict(os.environ, {'NEW_MODEL_DESIGN': ''}), \
             patch('jax.make_mesh') as make_mesh, \
             patch('tpu_inference.runner.tpu_runner.logger'):

            fake_mesh = Mock()
            make_mesh.return_value = fake_mesh

            runner_instance._init_mesh()

            make_mesh.assert_called_once()
            positional, keywords = make_mesh.call_args

            # Mesh shape.
            assert positional[0] == (4, 8)
            # Axis names.
            assert positional[1] == ("data", "model")
            # Devices are forwarded by keyword.
            assert keywords['devices'] == runner_instance.devices

            assert runner_instance.mesh == fake_mesh

    def test_init_mesh_new_model_single_slice(self, runner_instance,
                                              mock_vllm_config):
        """New model design with NUM_SLICES=1 uses create_device_mesh."""
        with patch.dict(os.environ, {'NEW_MODEL_DESIGN': '1', 'NUM_SLICES': '1'}), \
             patch('tpu_inference.runner.tpu_runner.mesh_utils') as mesh_utils_mock, \
             patch('jax.sharding.Mesh') as mesh_cls, \
             patch('tpu_inference.runner.tpu_runner.logger'):

            device_grid = Mock()
            mesh_utils_mock.create_device_mesh.return_value = device_grid
            fake_mesh = Mock()
            mesh_cls.return_value = fake_mesh

            runner_instance._init_mesh()

            # The single-slice path must build the device grid exactly once.
            mesh_utils_mock.create_device_mesh.assert_called_once()
            positional, keywords = mesh_utils_mock.create_device_mesh.call_args

            # Shape is (model_dp_size, attn_dp_size, expert_size, tp_size).
            assert positional[0] == (4, 2, 1, 8)
            assert positional[1] == runner_instance.devices
            assert keywords['allow_split_physical_axes'] is True

            # The Mesh wraps that grid with the four axis names.
            mesh_cls.assert_called_once_with(
                device_grid, ("data", "attn_dp", "expert", "model"))

            assert runner_instance.mesh == fake_mesh

    def test_init_mesh_new_model_multi_slice(self, runner_instance,
                                             mock_vllm_config):
        """New model design with NUM_SLICES=2 uses create_hybrid_device_mesh."""
        num_slices = 2
        with patch.dict(os.environ, {'NEW_MODEL_DESIGN': '1', 'NUM_SLICES': str(num_slices)}), \
             patch('tpu_inference.runner.tpu_runner.mesh_utils') as mesh_utils_mock, \
             patch('jax.sharding.Mesh') as mesh_cls, \
             patch('tpu_inference.runner.tpu_runner.logger'):

            device_grid = Mock()
            mesh_utils_mock.create_hybrid_device_mesh.return_value = device_grid
            fake_mesh = Mock()
            mesh_cls.return_value = fake_mesh

            runner_instance._init_mesh()

            # The multi-slice path must build the hybrid grid exactly once.
            mesh_utils_mock.create_hybrid_device_mesh.assert_called_once()
            _, keywords = mesh_utils_mock.create_hybrid_device_mesh.call_args

            # Per-slice shape is (dp_inner, attn_dp_size, expert_size, tp_size),
            # with dp_inner = model_dp_size // num_slices = 4 // 2 = 2.
            assert keywords['mesh_shape'] == (2, 2, 1, 8)
            # Cross-slice shape is (num_slices, 1, 1, 1).
            assert keywords['dcn_mesh_shape'] == (2, 1, 1, 1)
            assert keywords['devices'] == runner_instance.devices
            assert keywords['allow_split_physical_axes'] is True

            # The Mesh wraps that grid with the four axis names.
            mesh_cls.assert_called_once_with(
                device_grid, ("data", "attn_dp", "expert", "model"))

            assert runner_instance.mesh == fake_mesh

    @pytest.mark.parametrize("num_slices,expected_dp_inner", [
        (1, 4),
        (2, 2),
        (4, 1),
    ])
    def test_multi_slice_mesh_dp_inner_calculation(self, runner_instance,
                                                   mock_vllm_config,
                                                   num_slices,
                                                   expected_dp_inner):
        """The first mesh_shape entry equals the expected dp_inner per slice count."""
        with patch('tpu_inference.runner.tpu_runner.mesh_utils'
                   ) as mesh_utils_mock:
            mesh_utils_mock.create_hybrid_device_mesh.return_value = Mock()

            runner_instance._create_multi_slice_mesh(num_slices)

            _, keywords = mesh_utils_mock.create_hybrid_device_mesh.call_args
            per_slice_shape = keywords['mesh_shape']

            # dp_inner is the leading dimension of the per-slice shape.
            assert per_slice_shape[0] == expected_dp_inner
@@ -0,0 +1,411 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ import io
3
+ import logging
4
+ import time
5
+ from unittest.mock import MagicMock, mock_open, patch
6
+
7
+ import jax
8
+ import jax.numpy as jnp
9
+ import numpy as np
10
+ import pytest
11
+ from jax._src.interpreters import pxla
12
+
13
+ from tpu_inference.runner.utils import (
14
+ PHASED_PROFILER_NUM_STEPS_TO_PROFILE_FOR, ForbidCompile, InferencePhase,
15
+ LatencyTracker, PhasedBasedProfiler,
16
+ determine_phase_from_batch_composition_stats, get_batch_composition_stats,
17
+ get_padded_num_reqs_with_upper_limit, get_padded_token_len,
18
+ get_req_paddings, get_token_paddings)
19
+
20
+
21
def test_get_padded_num_reqs_with_upper_limit():
    """Check get_padded_num_reqs_with_upper_limit against known results.

    Each row is ``(num_reqs, upper_limit, expected_padded)``.
    """
    # The smallest expected result is 8; the original test attributes this to
    # MIN_NUM_SEQS = 8 in utils.py. The (100, 64) row shows the upper limit
    # capping the padded value.
    cases = [
        (4, 128, 8),
        (8, 128, 8),
        (9, 128, 16),
        (16, 128, 16),
        (17, 128, 32),
        (100, 64, 64),
        (1, 128, 8),
    ]
    for num_reqs, upper_limit, expected in cases:
        assert get_padded_num_reqs_with_upper_limit(num_reqs,
                                                    upper_limit) == expected
31
+
32
+
33
def test_get_paddings():
    """Check get_token_paddings for bucketed and exponential padding.

    Four cases are covered: bucketed (padding_gap=64) and exponential
    (padding_gap=0), each with a maximum that is and is not a power of two.
    """
    # Bucketed padding
    min_token_size, max_token_size, padding_gap = 16, 512, 64
    expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
    actual_paddings = get_token_paddings(min_token_size, max_token_size,
                                         padding_gap)
    # This result used to be computed and then overwritten without ever being
    # compared, leaving the case untested.
    assert actual_paddings == expected_paddings

    # Bucketed padding with max_token_size not a power of two.
    max_token_size = 317
    expected_paddings = [16, 32, 64, 128, 192, 256, 320]
    actual_paddings = get_token_paddings(min_token_size, max_token_size,
                                         padding_gap)
    assert actual_paddings == expected_paddings

    # Exponential padding.
    max_token_size, padding_gap = 1024, 0
    expected_paddings = [16, 32, 64, 128, 256, 512, 1024]
    actual_paddings = get_token_paddings(min_token_size, max_token_size,
                                         padding_gap)
    assert actual_paddings == expected_paddings

    # Exponential padding with max_token_size not a power of two.
    max_token_size = 317
    expected_paddings = [16, 32, 64, 128, 256, 512]
    actual_paddings = get_token_paddings(min_token_size, max_token_size,
                                         padding_gap)
    assert actual_paddings == expected_paddings
59
+
60
+
61
def test_get_padded_token_len():
    """Check get_padded_token_len against paddings built with (16, 512, 64)."""
    paddings = get_token_paddings(16, 512, 64)

    # (token count, expected padded length)
    cases = [
        (1, 16),
        (16, 16),
        (20, 32),
        (300, 320),
        (512, 512),
    ]
    for token_len, expected in cases:
        assert get_padded_token_len(paddings, token_len) == expected
69
+
70
+
71
def test_get_req_paddings():
    """Check get_req_paddings for three (min, max) request-count pairs."""
    # ((min_reqs, max_reqs), expected padding list)
    cases = [
        ((1, 32), [8, 16, 32]),
        ((8, 32), [8, 16, 32]),
        ((8, 36), [8, 16, 32, 36]),
    ]
    for (min_reqs, max_reqs), expected in cases:
        assert get_req_paddings(min_reqs, max_reqs) == expected
75
+
76
+
77
def test_latency_tracker(caplog):
    """Check that LatencyTracker records timestamps and logs the latency.

    A temporary StreamHandler backed by an in-memory buffer is attached to the
    "vllm.tpu_inference.runner.utils" logger so the emitted text can be read
    back. The logger's level, propagate flag and handler list are restored in
    the ``finally`` clause.
    """
    utils_logger = logging.getLogger("vllm.tpu_inference.runner.utils")
    saved_level = utils_logger.level
    saved_propagate = utils_logger.propagate

    # In-memory sink plus the handler that writes into it.
    buffer = io.StringIO()
    handler = logging.StreamHandler(buffer)

    try:
        utils_logger.setLevel(logging.DEBUG)
        utils_logger.propagate = False
        utils_logger.addHandler(handler)

        nap = 0.01
        with LatencyTracker("test_op") as tracker:
            time.sleep(nap)

        # Both timestamps are read only after the context has exited.
        elapsed = tracker.end_time - tracker.start_time
        assert elapsed >= nap

        logged = buffer.getvalue()
        assert "Latency for 'test_op'" in logged
        assert f"{elapsed:.3f} seconds" in logged

    finally:
        # Always put the logger back the way it was found.
        utils_logger.setLevel(saved_level)
        utils_logger.propagate = saved_propagate
        utils_logger.removeHandler(handler)
111
+
112
+
113
+ # Define a fixture to clear the JAX cache before each test
114
@pytest.fixture(autouse=True)
def clear_jax_cache():
    """Clear JAX caches both before and after every test (autouse)."""
    jax.clear_caches()  # setup
    yield
    jax.clear_caches()  # teardown
119
+
120
+
121
@pytest.fixture
def jitted_function():
    """Return a freshly jitted function that doubles its argument."""

    def my_jitted_func(x):
        return x * 2

    # Equivalent to decorating the inner function with @jax.jit.
    return jax.jit(my_jitted_func)
130
+
131
+
132
@pytest.fixture
def jnp_array_input():
    """Return a (2, 3) array of ones."""
    shape = (2, 3)
    return jnp.ones(shape)
135
+
136
+
137
@pytest.fixture
def jnp_array_input_same_shape():
    """Return a (2, 3) array of zeros: same shape as jnp_array_input, different values."""
    shape = (2, 3)
    return jnp.zeros(shape)
140
+
141
+
142
@pytest.fixture
def jnp_array_input_new():
    """Return a (3, 3) array of ones, a different shape from the (2, 3) fixtures."""
    shape = (3, 3)
    return jnp.ones(shape)
145
+
146
+
147
def test_forbid_compile_raises_error_on_first_call(jitted_function,
                                                   jnp_array_input):
    """Calling a jitted function for the first time under ForbidCompile raises RuntimeError."""
    # pytest.raises is entered first, so it observes the error raised inside
    # the ForbidCompile context.
    with pytest.raises(RuntimeError,
                       match="JAX compilation occurred"), ForbidCompile():
        jitted_function(jnp_array_input)
153
+
154
+
155
def test_forbid_compile_succeeds_on_cached_call(jitted_function,
                                                jnp_array_input):
    """A repeat call with the same input runs under ForbidCompile without error."""
    # First call happens outside the guard so that the later call is a repeat.
    jitted_function(jnp_array_input)

    guard = ForbidCompile()
    with guard:
        jitted_function(jnp_array_input)
162
+
163
+
164
def test_forbid_compile_restores_original_function():
    """After ForbidCompile exits, pxla._cached_lowering_to_hlo is the same object as before."""
    before = pxla._cached_lowering_to_hlo

    with ForbidCompile():
        pass

    after = pxla._cached_lowering_to_hlo
    assert after is before
170
+
171
+
172
def test_forbid_compile_with_exception():
    """ForbidCompile restores pxla._cached_lowering_to_hlo when its body raises."""
    before = pxla._cached_lowering_to_hlo

    # The ValueError propagates out of ForbidCompile and is caught by
    # pytest.raises, which is entered first.
    with pytest.raises(ValueError, match="Test exception"), ForbidCompile():
        raise ValueError("Test exception")

    assert pxla._cached_lowering_to_hlo is before
179
+
180
+
181
def test_forbid_compile_raises_on_new_shape(jitted_function, jnp_array_input,
                                            jnp_array_input_same_shape,
                                            jnp_array_input_new):
    """
    ForbidCompile tolerates a call whose input shape was already seen and
    raises RuntimeError for a call with a new input shape.
    """
    # Start from an empty lowering cache so the miss counter is predictable.
    pxla._cached_lowering_to_hlo.cache_clear()

    # Warm-up call with the (2, 3) ones array: the test expects this to
    # register exactly one cache miss.
    jitted_function(jnp_array_input)
    baseline_misses = pxla._cached_lowering_to_hlo.cache_info().misses
    assert baseline_misses == 1

    # The zeros array has the same shape as the warm-up input, so no error and
    # no additional miss are expected.
    with ForbidCompile():
        jitted_function(jnp_array_input_same_shape)
    # Checked after the guard has exited, when the original function is back
    # in place.
    assert pxla._cached_lowering_to_hlo.cache_info().misses == baseline_misses

    # The (3, 3) array has a different shape from the warm-up input; the test
    # expects the guarded call to raise with the supplied message.
    expected_error_message = "JAX compilation occurred but was forbidden in this context."
    with pytest.raises(RuntimeError, match=expected_error_message), \
         ForbidCompile(message=expected_error_message):
        jitted_function(jnp_array_input_new)
211
+
212
+
213
class MockInputBatch:
    """Minimal stand-in for an input batch, holding only two attributes."""

    def __init__(self, req_ids, num_computed_tokens_cpu):
        # Stored as a NumPy array (np.array makes a copy of the argument).
        computed_tokens = np.array(num_computed_tokens_cpu)
        self.num_computed_tokens_cpu = computed_tokens
        # Request ids are kept exactly as passed in.
        self.req_ids = req_ids
218
+
219
+
220
class MockSchedulerOutput:
    """Minimal stand-in for scheduler output, holding only num_scheduled_tokens."""

    def __init__(self, num_scheduled_tokens):
        # Kept by reference; the tests in this file pass a dict keyed by request id.
        setattr(self, "num_scheduled_tokens", num_scheduled_tokens)
224
+
225
+
226
@pytest.mark.parametrize(
    "scenario, num_reqs, req_ids, computed, scheduled, expected_prefill, expected_decode",
    [
        ("prefill_only", 2, [101, 102], [0, 0], {101: 50, 102: 100}, 150, 0),
        ("decode_only", 3, [201, 202, 203], [10, 20, 5],
         {201: 1, 202: 1, 203: 1}, 0, 3),
        ("mixed_batch", 4, [301, 302, 303, 304], [0, 10, 0, 20],
         {301: 100, 302: 1, 303: 50, 304: 1}, 150, 2),
        ("chunked_prefill", 2, [401, 402], [50, 10], {401: 50, 402: 1}, 50, 1),
    ])
def test_get_batch_composition_stats(scenario, num_reqs, req_ids, computed,
                                     scheduled, expected_prefill,
                                     expected_decode):
    """Check the prefill/decode token counts reported for several batch mixes.

    ``computed`` gives the already-computed token count per request and
    ``scheduled`` maps request id to the number of tokens scheduled this step.
    """
    total_tokens = sum(scheduled.values())

    stats = get_batch_composition_stats(
        input_batch=MockInputBatch(req_ids, computed),
        total_num_scheduled_tokens=total_tokens,
        num_reqs=num_reqs,
        # The padded total is simply the real total plus 8 for this test.
        padded_total_num_scheduled_tokens=total_tokens + 8,
        scheduler_output=MockSchedulerOutput(scheduled))

    expected_stats = {
        "num_prefill_tokens": expected_prefill,
        "num_decode_tokens": expected_decode,
        "num_reqs": num_reqs,
        "total_num_scheduled_tokens": total_tokens,
    }
    for key, expected in expected_stats.items():
        assert stats[key] == expected
268
+
269
+
270
@pytest.mark.parametrize("prefill_tokens, total_tokens, expected_phase", [
    # Prefill-heavy: 90 is the smallest prefill count listed for this phase.
    (90, 100, InferencePhase.PREFILL_HEAVY),
    (100, 100, InferencePhase.PREFILL_HEAVY),
    # Decode-heavy: 20 is the largest prefill count listed for this phase.
    (0, 100, InferencePhase.DECODE_HEAVY),
    (15, 100, InferencePhase.DECODE_HEAVY),
    (20, 100, InferencePhase.DECODE_HEAVY),
    # Balanced: 40 through 60 (each case listed once).
    (40, 100, InferencePhase.BALANCED),
    (50, 100, InferencePhase.BALANCED),
    (60, 100, InferencePhase.BALANCED),
    # Ambiguous: values just outside the bands above.
    (21, 100, InferencePhase.AMBIGUOUS),
    (30, 100, InferencePhase.AMBIGUOUS),
    (70, 100, InferencePhase.AMBIGUOUS),
    (89, 100, InferencePhase.AMBIGUOUS),
])
def test_determine_phase_from_batch_composition_stats(prefill_tokens,
                                                      total_tokens,
                                                      expected_phase):
    """Tests the phase determination logic based on prefill ratios.

    Each case supplies a prefill-token count and a total-token count and
    checks the phase returned for a stats dict holding just those two keys.
    """
    stats = {
        "num_prefill_tokens": prefill_tokens,
        "total_num_scheduled_tokens": total_tokens
    }
    phase = determine_phase_from_batch_composition_stats(stats)
    assert phase == expected_phase
295
+
296
+
297
@pytest.fixture
def profiler_fixture(tmp_path):
    """Yields a PhasedBasedProfiler together with the mocks around it.

    While the fixture is active, the JAX trace start/stop functions,
    ``builtins.open``, ``datetime`` and the phase-determination helper are
    patched in ``tpu_inference.runner.utils``. The yielded dict exposes the
    profiler and the mocks that tests assert against.
    """
    target_module = "tpu_inference.runner.utils"
    with patch(f"{target_module}.jax.profiler.start_trace") as mock_start, \
         patch(f"{target_module}.jax.profiler.stop_trace") as mock_stop, \
         patch("builtins.open", mock_open()) as mock_file, \
         patch(f"{target_module}.datetime") as mock_datetime, \
         patch(f"{target_module}.InferencePhase", InferencePhase), \
         patch(f"{target_module}.determine_phase_from_batch_composition_stats") as mock_determine_phase:

        # Pin the formatted timestamp so it is the same on every run.
        frozen_now = MagicMock()
        frozen_now.strftime.return_value = "2024_01_01_12_00_00"
        mock_datetime.datetime.now.return_value = frozen_now

        phased_profiler = PhasedBasedProfiler(profile_dir=str(tmp_path))
        phased_profiler.num_steps_to_profile_for = (
            PHASED_PROFILER_NUM_STEPS_TO_PROFILE_FOR)

        # Yield inside the ``with`` so the patches stay active for the
        # whole duration of the test using this fixture.
        handles = {
            "profiler": phased_profiler,
            "mock_start": mock_start,
            "mock_stop": mock_stop,
            "mock_file": mock_file,
            "mock_determine_phase": mock_determine_phase,
        }
        yield handles
322
+
323
+
324
def test_phased_profiler_full_cycle(profiler_fixture):
    """Walks one phase through start, intermediate steps and stop."""
    profiler = profiler_fixture["profiler"]
    start_trace = profiler_fixture["mock_start"]
    stop_trace = profiler_fixture["mock_stop"]
    mock_file = profiler_fixture["mock_file"]
    determine_phase = profiler_fixture["mock_determine_phase"]

    n_steps = PHASED_PROFILER_NUM_STEPS_TO_PROFILE_FOR
    stats = {"num_reqs": 2, "total_num_scheduled_tokens": 100}

    # 1. The first step in a PREFILL_HEAVY phase starts the trace.
    determine_phase.return_value = InferencePhase.PREFILL_HEAVY
    profiler.step(stats)
    start_trace.assert_called_once()
    assert profiler.profiling_n_steps_left == n_steps
    assert profiler.current_phase == "prefill_heavy"
    assert profiler.inference_phase_seen[InferencePhase.PREFILL_HEAVY]
    # One write on the mocked file handle at start.
    assert mock_file().write.call_count == 1

    # 2. The next N-1 steps count the remaining steps down to 1.
    for steps_left in range(n_steps - 1, 0, -1):
        profiler.step(stats)
        assert profiler.profiling_n_steps_left == steps_left
    start_trace.assert_called_once()  # Still only the initial start.
    stop_trace.assert_not_called()

    # 3. One more step ends the trace and resets the profiler state.
    profiler.step(stats)
    stop_trace.assert_called_once()
    assert profiler.profiling_n_steps_left == 0
    assert profiler.current_phase == ""
    assert mock_file().write.call_count == n_steps + 1
357
+
358
+
359
def test_phased_profiler_ignores_initial_request(profiler_fixture):
    """Checks that tracing does not start for single-request or single-token batches.

    The trace must only start once a batch has 2 requests and 2 scheduled
    tokens, even though the mocked phase is PREFILL_HEAVY throughout.
    """
    profiler = profiler_fixture["profiler"]
    start_trace = profiler_fixture["mock_start"]
    profiler_fixture[
        "mock_determine_phase"].return_value = InferencePhase.PREFILL_HEAVY

    # Batches with a single request and/or a single scheduled token.
    non_triggering_batches = (
        {"num_reqs": 1, "total_num_scheduled_tokens": 1},
        {"num_reqs": 1, "total_num_scheduled_tokens": 100},
        {"num_reqs": 2, "total_num_scheduled_tokens": 1},
    )
    for batch_stats in non_triggering_batches:
        profiler.step(batch_stats)
        start_trace.assert_not_called()

    # Two requests with two tokens is the first batch that starts a trace.
    profiler.step({"num_reqs": 2, "total_num_scheduled_tokens": 2})
    start_trace.assert_called_once()
378
+
379
+
380
def test_phased_profiler_handles_all_phases(profiler_fixture):
    """Profiles each phase once, in sequence, then expects no further starts."""
    profiler = profiler_fixture["profiler"]
    start_trace = profiler_fixture["mock_start"]
    stop_trace = profiler_fixture["mock_stop"]
    determine_phase = profiler_fixture["mock_determine_phase"]

    stats = {"num_reqs": 2, "total_num_scheduled_tokens": 100}
    phases_in_order = (
        InferencePhase.PREFILL_HEAVY,
        InferencePhase.DECODE_HEAVY,
        InferencePhase.BALANCED,
    )

    for cycles_done, phase in enumerate(phases_in_order, start=1):
        # The first step in a not-yet-seen phase starts a new trace.
        determine_phase.return_value = phase
        profiler.step(stats)
        assert start_trace.call_count == cycles_done
        assert profiler.current_phase == phase.name.lower()
        assert profiler.inference_phase_seen[phase]

        # Keep stepping until the trace for this phase has been stopped.
        for _ in range(PHASED_PROFILER_NUM_STEPS_TO_PROFILE_FOR):
            profiler.step(stats)

        assert stop_trace.call_count == cycles_done
        assert profiler.current_phase == ""

    # Every phase has been seen, so a repeated phase must not start a trace.
    determine_phase.return_value = InferencePhase.PREFILL_HEAVY
    profiler.step(stats)
    assert start_trace.call_count == len(phases_in_order)
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.