tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,414 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from unittest.mock import MagicMock, patch
16
+
17
+ import pytest
18
+ from vllm.config import ModelConfig
19
+ from vllm.lora.request import LoRARequest
20
+ from vllm.v1.kv_cache_interface import KVCacheConfig
21
+ from vllm.v1.outputs import DraftTokenIds
22
+
23
+ # The class we are testing
24
+ from tpu_inference.worker.tpu_worker import TPUWorker
25
+
26
+
27
@pytest.fixture
def mock_vllm_config():
    """Build a mock VllmConfig object for tests.

    The mock is assembled attribute-by-attribute without a top-level spec,
    so accessing nested config objects never raises spec-related
    AttributeErrors.
    """
    cache_conf = MagicMock()
    cache_conf.configure_mock(gpu_memory_utilization=0.9,
                              num_gpu_blocks=0,
                              num_cpu_blocks=0)

    parallel_conf = MagicMock()
    parallel_conf.configure_mock(tensor_parallel_size=2,
                                 data_parallel_size=1,
                                 pipeline_parallel_size=1,
                                 nnodes=1,
                                 nnodes_within_dp=1)

    sharding_conf = MagicMock()
    sharding_conf.total_devices = 2

    # Attach the nested configs to the top-level mock.
    config = MagicMock()
    config.model_config = ModelConfig(model="Qwen/Qwen3-0.6B")
    config.cache_config = cache_conf
    config.parallel_config = parallel_conf
    config.additional_config = {}
    config.sharding_config = sharding_conf
    return config
59
+
60
+
61
class TestTPUWorker:
    """Test suite for the TPUWorker class.

    All collaborators (JAX, the TPUModelRunner, vLLM env vars, KV-transfer
    setup) are mocked, so these tests exercise only TPUWorker's own wiring
    and delegation logic.
    """

    #
    # --- Initialization Tests ---
    #

    def test_init_success(self, mock_vllm_config):
        """Tests successful initialization of TPUWorker."""
        worker = TPUWorker(vllm_config=mock_vllm_config,
                           local_rank=0,
                           rank=0,
                           distributed_init_method="test_method",
                           is_driver_worker=True,
                           devices=['tpu:0'])
        assert worker.vllm_config == mock_vllm_config
        assert worker.rank == 0
        assert worker.local_rank == 0
        assert worker.is_driver_worker
        # No profiler env var is patched here, so no profile dir is set.
        assert worker.profile_dir is None
        assert worker.devices == ['tpu:0']

    @patch('tpu_inference.worker.tpu_worker.vllm_envs')
    def test_init_with_profiler_on_rank_zero(self, mock_envs,
                                             mock_vllm_config):
        """Tests that the profiler directory is set correctly on rank 0."""
        mock_envs.VLLM_TORCH_PROFILER_DIR = "/tmp/profiles"
        worker = TPUWorker(vllm_config=mock_vllm_config,
                           local_rank=0,
                           rank=0,
                           distributed_init_method="test_method")
        assert worker.profile_dir == "/tmp/profiles"

    @patch('tpu_inference.worker.tpu_worker.vllm_envs')
    def test_init_with_profiler_on_other_ranks(self, mock_envs,
                                               mock_vllm_config):
        """Tests that the profiler directory is NOT set on non-rank 0 workers."""
        mock_envs.VLLM_TORCH_PROFILER_DIR = "/tmp/profiles"
        worker = TPUWorker(vllm_config=mock_vllm_config,
                           local_rank=1,
                           rank=1,
                           distributed_init_method="test_method")
        assert worker.profile_dir is None

    #
    # --- Device and Cache Initialization Tests ---
    #

    def test_initialize_cache(self, mock_vllm_config):
        """Tests setting the number of GPU and CPU cache blocks."""
        worker = TPUWorker(vllm_config=mock_vllm_config,
                           local_rank=0,
                           rank=0,
                           distributed_init_method="test_method")
        worker.initialize_cache(num_gpu_blocks=2048, num_cpu_blocks=1024)
        assert worker.cache_config.num_gpu_blocks == 2048
        assert worker.cache_config.num_cpu_blocks == 1024

    @patch('tpu_inference.worker.tpu_worker.TPUModelRunner')
    @patch('tpu_inference.worker.tpu_worker.utils')
    @patch('tpu_inference.worker.tpu_worker.jax')
    @patch('tpu_inference.worker.tpu_worker.ensure_kv_transfer_initialized')
    def test_init_device_with_provided_devices(
            self, mock_ensure_kv_transfer_initialized, mock_jax, mock_utils,
            mock_runner_cls, mock_vllm_config):
        """Tests init_device when devices are provided during construction."""
        mock_devices = ['tpu:0', 'tpu:1']
        worker = TPUWorker(vllm_config=mock_vllm_config,
                           local_rank=0,
                           rank=0,
                           distributed_init_method="test_method",
                           devices=mock_devices)

        worker.init_device()

        # Single-node, single-pipeline-stage setup: rank 0 is both the
        # first and the last rank.
        expected_rank = 0
        expected_is_first_rank = True
        expected_is_last_rank = True
        mock_runner_cls.assert_called_once_with(mock_vllm_config, mock_devices,
                                                expected_rank,
                                                expected_is_first_rank,
                                                expected_is_last_rank)
        assert isinstance(worker.model_runner, MagicMock)

    @patch('tpu_inference.worker.tpu_worker.TPUModelRunner')
    @patch('tpu_inference.worker.tpu_worker.utils')
    @patch('tpu_inference.worker.tpu_worker.jax')
    @patch('tpu_inference.worker.tpu_worker.ensure_kv_transfer_initialized')
    def test_init_device_autodetects_devices(
            self, mock_ensure_kv_transfer_initialized, mock_jax, mock_utils,
            mock_runner_cls, mock_vllm_config):
        """Tests init_device when devices are auto-detected via JAX."""
        worker = TPUWorker(
            vllm_config=mock_vllm_config,
            local_rank=0,
            rank=0,
            distributed_init_method="test_method",
            devices=[]  # No devices provided, should trigger auto-detection
        )
        mock_jax.device_count.return_value = 4
        mock_jax.devices.return_value = ['tpu:0', 'tpu:1', 'tpu:2', 'tpu:3']

        worker.init_device()

        # The fixture configures tensor_parallel_size=2, so only the first
        # two auto-detected devices should be kept.
        expected_devices = ['tpu:0', 'tpu:1']  # Sliced by tensor_parallel_size
        assert worker.devices == expected_devices
        expected_rank = 0
        expected_is_first_rank = True
        expected_is_last_rank = True
        mock_runner_cls.assert_called_once_with(mock_vllm_config,
                                                expected_devices,
                                                expected_rank,
                                                expected_is_first_rank,
                                                expected_is_last_rank)

    @patch('tpu_inference.worker.tpu_worker.utils')
    def test_determine_available_memory(self, mock_utils, mock_vllm_config):
        """Tests the available HBM memory calculation."""
        # Setup mock return for hbm_usage_bytes: [(used_bytes, limit_bytes), ...]
        mock_utils.hbm_usage_bytes.return_value = [
            (100 * 1024**3, 1000 * 1024**3), (200 * 1024**3, 1000 * 1024**3)
        ]
        mock_devices = ['tpu:0', 'tpu:1']
        worker = TPUWorker(vllm_config=mock_vllm_config,
                           local_rank=0,
                           rank=0,
                           distributed_init_method="test_method",
                           devices=mock_devices)

        available_mem = worker.determine_available_memory()

        mock_utils.hbm_usage_bytes.assert_called_once_with(mock_devices)
        # Total limit: 1000 + 1000 = 2000 GiB
        # Total cap: 2000 * 0.9 = 1800 GiB (fixture gpu_memory_utilization)
        # Total used: 100 + 200 = 300 GiB
        # Total free = 1800 - 300 = 1500 GiB
        expected_mem = 1500 * 1024**3
        assert available_mem == expected_mem

    #
    # --- Core Logic Tests ---
    #

    @patch('tpu_inference.worker.tpu_worker.TPUModelRunner')
    def test_execute_model(self, mock_runner_cls, mock_vllm_config):
        """Tests that the driver worker executes the model and returns the concrete vLLM output."""
        worker = TPUWorker(vllm_config=mock_vllm_config,
                           local_rank=0,
                           rank=0,
                           distributed_init_method="test",
                           is_driver_worker=True)
        worker.model_runner = mock_runner_cls.return_value  # Assign mocked runner instance
        mock_scheduler_input = MagicMock()

        # The model runner returns a concrete vllm output
        mock_model_output = "concrete_model_output"
        worker.model_runner.execute_model.return_value = mock_model_output

        result = worker.execute_model(mock_scheduler_input)

        # Assert the runner was called with the scheduler output directly
        worker.model_runner.execute_model.assert_called_once_with(
            mock_scheduler_input, None)
        # Assert the final result is the concrete model output
        assert result == mock_model_output

    @patch('tpu_inference.worker.tpu_worker.TPUModelRunner')
    def test_execute_model_non_driver_returns_none(self, mock_runner_cls,
                                                   mock_vllm_config):
        """Tests that a non-driver worker executes the model but returns None."""
        worker = TPUWorker(
            vllm_config=mock_vllm_config,
            local_rank=0,
            rank=0,
            distributed_init_method="test",
            is_driver_worker=False  # Not a driver
        )
        worker.model_runner = mock_runner_cls.return_value
        mock_scheduler_input = MagicMock()

        result = worker.execute_model(mock_scheduler_input)

        assert result is None

    def test_take_draft_token_ids(self, mock_vllm_config):
        """Tests that take_draft_token_ids correctly passes through from the runner."""
        worker = TPUWorker(vllm_config=mock_vllm_config,
                           local_rank=0,
                           rank=0,
                           distributed_init_method="test")
        worker.model_runner = MagicMock()

        # Case 1: Runner returns a DraftTokenIds object
        mock_draft_tokens = DraftTokenIds(req_ids=["req1"],
                                          draft_token_ids=[[1, 2]])
        worker.model_runner.take_draft_token_ids.return_value = mock_draft_tokens

        result = worker.take_draft_token_ids()
        worker.model_runner.take_draft_token_ids.assert_called_once()
        assert result == mock_draft_tokens

    def test_add_lora_not_implemented(self, mock_vllm_config):
        """Tests that add_lora raises NotImplementedError."""
        worker = TPUWorker(vllm_config=mock_vllm_config,
                           local_rank=0,
                           rank=0,
                           distributed_init_method="test")
        mock_lora_request = MagicMock()

        with pytest.raises(
                NotImplementedError,
                match="LoRA is not supported by the JAX worker yet."):
            worker.add_lora(mock_lora_request)

    def test_add_lora_not_implemented_lora_request(self, mock_vllm_config):
        """Tests that add_lora raises NotImplementedError."""
        worker = TPUWorker(vllm_config=mock_vllm_config,
                           local_rank=0,
                           rank=0,
                           distributed_init_method="test")
        # Same as above but spec'd as LoRARequest so any isinstance checks
        # in the worker see a "real" request type.
        mock_lora_request = MagicMock(spec=LoRARequest)

        with pytest.raises(
                NotImplementedError,
                match="LoRA is not supported by the JAX worker yet."):
            worker.add_lora(mock_lora_request)

    #
    # --- Profiling and Health Check Tests ---
    #

    @patch('tpu_inference.worker.tpu_worker.jax')
    @patch.dict('os.environ', {"PYTHON_TRACER_LEVEL": "1"}, clear=True)
    def test_profile_start(self, mock_jax, mock_vllm_config):
        """Tests starting the JAX profiler."""
        worker = TPUWorker(vllm_config=mock_vllm_config,
                           local_rank=0,
                           rank=0,
                           distributed_init_method="test")
        worker.profile_dir = "/tmp/profile_dir"

        worker.profile(is_start=True)

        mock_jax.profiler.ProfileOptions.assert_called_once()
        mock_jax.profiler.start_trace.assert_called_once()
        args, kwargs = mock_jax.profiler.start_trace.call_args
        assert args[0] == "/tmp/profile_dir"
        # Verify options from env var were used
        assert kwargs['profiler_options'].python_tracer_level == 1

    @patch('tpu_inference.worker.tpu_worker.jax')
    def test_profile_stop(self, mock_jax, mock_vllm_config):
        """Tests stopping the JAX profiler."""
        worker = TPUWorker(vllm_config=mock_vllm_config,
                           local_rank=0,
                           rank=0,
                           distributed_init_method="test")
        worker.profile(is_start=False)
        mock_jax.profiler.stop_trace.assert_called_once()

    def test_check_health(self, mock_vllm_config):
        """Tests that check_health runs without error."""
        worker = TPUWorker(vllm_config=mock_vllm_config,
                           local_rank=0,
                           rank=0,
                           distributed_init_method="test")
        try:
            worker.check_health()
        except Exception as e:
            pytest.fail(
                f"TPUWorker.check_health() raised an unexpected exception: {e}"
            )

    #
    # --- Pass-through Method Tests ---
    #

    @pytest.mark.parametrize(
        "worker_method_name, runner_method_name, method_args", [
            ("load_model", "load_model", []),
            ("get_model", "get_model", []),
            ("get_kv_cache_spec", "get_kv_cache_spec", []),
        ])
    def test_runner_passthrough_methods(self, worker_method_name,
                                        runner_method_name, method_args,
                                        mock_vllm_config):
        """Tests methods that are simple pass-throughs to the TPUModelRunner."""
        worker = TPUWorker(vllm_config=mock_vllm_config,
                           local_rank=0,
                           rank=0,
                           distributed_init_method="test")
        worker.model_runner = MagicMock()

        # Call the worker method and assert the underlying runner method was called
        getattr(worker, worker_method_name)(*method_args)
        mock_runner_method = getattr(worker.model_runner, runner_method_name)
        mock_runner_method.assert_called_once_with(*method_args)

    def test_initialize_from_config(self, mock_vllm_config):
        """Tests the special case pass-through for initialize_from_config."""
        worker = TPUWorker(vllm_config=mock_vllm_config,
                           local_rank=0,
                           rank=0,
                           distributed_init_method="test")
        worker.model_runner = MagicMock()
        worker.topology_order_id = 0
        mock_input_config = MagicMock()

        worker.initialize_from_config(mock_input_config)

        worker.model_runner.initialize_kv_cache.assert_called_once_with(
            mock_input_config, 0)

    def test_initialize_from_config_kv_cache_config(self, mock_vllm_config):
        """Tests the special case pass-through for initialize_from_config."""
        worker = TPUWorker(vllm_config=mock_vllm_config,
                           local_rank=0,
                           rank=0,
                           distributed_init_method="test")
        worker.model_runner = MagicMock()
        worker.topology_order_id = 0
        # Spec'd as KVCacheConfig so isinstance checks in the worker pass.
        mock_input_config = MagicMock(spec=KVCacheConfig)

        worker.initialize_from_config(mock_input_config)

        worker.model_runner.initialize_kv_cache.assert_called_once_with(
            mock_input_config, 0)

    def test_compile_or_warm_up_model(self, mock_vllm_config):
        """Tests the special case pass-through for model compilation/warmup."""
        worker = TPUWorker(vllm_config=mock_vllm_config,
                           local_rank=0,
                           rank=0,
                           distributed_init_method="test")
        worker.model_runner = MagicMock()

        worker.compile_or_warm_up_model()

        # This method calls two different runner methods
        worker.model_runner.capture_model.assert_called_once()
        worker.model_runner._init_random.assert_called_once()

    def test_get_supported_tasks(self, mock_vllm_config):
        """Test get_supported_tasks passthrough to model runner."""
        worker = TPUWorker(vllm_config=mock_vllm_config,
                           local_rank=0,
                           rank=0,
                           distributed_init_method="test")
        worker.model_runner = MagicMock()
        worker.model_runner.get_supported_tasks.return_value = ("generate", )

        _ = worker.get_supported_tasks()

        worker.model_runner.get_supported_tasks.assert_called_once()
@@ -0,0 +1,67 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
# The environment variables override should be imported before any other
# modules to ensure that the environment variables are set before any
# other modules are imported.
import tpu_inference.env_override  # noqa: F401
from tpu_inference import envs
from tpu_inference import tpu_info as ti
from tpu_inference.logger import init_logger

logger = init_logger(__name__)

# Two startup modes, selected by the JAX_PLATFORMS env var:
#  * contains "proxy" (Pathways): initialize pathwaysutils and eagerly
#    resolve the vLLM platform before anything else is loaded.
#  * otherwise (direct TPU or CPU): just log basic TPU topology info.
if "proxy" in envs.JAX_PLATFORMS:
    logger.info("Running vLLM on TPU via Pathways proxy.")
    # Must run pathwaysutils.initialize() before any JAX operations
    try:
        import traceback

        import pathwaysutils
        import vllm
        from vllm.platforms import (resolve_current_platform_cls_qualname,
                                    resolve_obj_by_qualname)
        pathwaysutils.initialize()
        logger.info("Module pathwaysutils is imported.")

        # Pathways requires eager resolution of vllm.current_platform instead of
        # lazy resolution in the normal code path. Since this part involves
        # global topology discovery across multiple hosts, the platform
        # resolution must happen before other components are loaded.
        logger.info("Eagerly resolving vLLM current_platform for Pathways.")
        platform_cls_qualname = resolve_current_platform_cls_qualname()
        resolved_platform_instance = resolve_obj_by_qualname(
            platform_cls_qualname)()
        # Install the resolved platform into vLLM's module-level cache so
        # later vllm.current_platform lookups reuse it; _init_trace records
        # where resolution happened, for debugging.
        vllm.platforms._current_platform = resolved_platform_instance
        vllm.platforms._init_trace = "".join(traceback.format_stack())
        logger.info(
            f"vLLM platform resolved to: {resolved_platform_instance.__class__.__name__}"
        )

    except Exception as e:
        # Best-effort: import of this package proceeds even if Pathways
        # setup fails; the error is only logged.
        logger.error(
            f"Error occurred while importing pathwaysutils or logging TPU info: {e}"
        )
else:
    # Either running on TPU or CPU
    try:
        logger.info(f"TPU info: node_name={ti.get_node_name()} | "
                    f"tpu_type={ti.get_tpu_type()} | "
                    f"worker_id={ti.get_node_worker_id()} | "
                    f"num_chips={ti.get_num_chips()} | "
                    f"num_cores_per_chip={ti.get_num_cores_per_chip()}")
    except Exception as e:
        # TPU metadata queries are expected to fail off-TPU; log and
        # continue (most likely running on CPU).
        logger.error(
            f"Error occurred while logging TPU info: {e}. Are you running on CPU?"
        )
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.