tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tests/test_tpu_info.py ADDED
@@ -0,0 +1,120 @@
1
+ import os
2
+ from unittest.mock import MagicMock, patch
3
+
4
+ import pytest
5
+ import requests
6
+
7
+ from tpu_inference.tpu_info import (get_node_name, get_node_worker_id,
8
+ get_num_chips, get_num_cores_per_chip,
9
+ get_tpu_metadata, get_tpu_type)
10
+
11
+
12
+ # Mock requests.get for get_tpu_metadata tests
13
+ @patch("tpu_inference.tpu_info.requests.get")
14
+ def test_get_tpu_metadata_success(mock_get):
15
+ """Test get_tpu_metadata when the request is successful."""
16
+ mock_response = MagicMock()
17
+ mock_response.status_code = 200
18
+ mock_response.text = "test_metadata_value"
19
+ mock_get.return_value = mock_response
20
+ assert get_tpu_metadata(key="test-key") == "test_metadata_value"
21
+
22
+
23
+ @patch("tpu_inference.tpu_info.requests.get")
24
+ def test_get_tpu_metadata_request_error(mock_get):
25
+ """Test get_tpu_metadata when a RequestException is raised."""
26
+ mock_get.side_effect = requests.RequestException("Test RequestException")
27
+ assert get_tpu_metadata(key="test-key") is None
28
+
29
+
30
+ # Test get_tpu_type
31
+ @patch("tpu_inference.tpu_info.get_tpu_metadata")
32
+ @patch.dict(os.environ, {"TPU_ACCELERATOR_TYPE": "env_tpu_type"})
33
+ def test_get_tpu_type_from_env(mock_get_tpu_metadata):
34
+ """Test get_tpu_type when TPU_ACCELERATOR_TYPE is set in environment."""
35
+ # The function should return the env var value and not call get_tpu_metadata
36
+ assert get_tpu_type() == "env_tpu_type"
37
+ mock_get_tpu_metadata.assert_not_called()
38
+
39
+
40
+ @patch.dict(os.environ, {}, clear=True)
41
+ @patch("tpu_inference.tpu_info.get_tpu_metadata",
42
+ return_value="metadata_tpu_type")
43
+ def test_get_tpu_type_from_metadata(mock_get_tpu_metadata):
44
+ """Test get_tpu_type when environment variable is not set."""
45
+ assert get_tpu_type() == "metadata_tpu_type"
46
+ mock_get_tpu_metadata.assert_called_once_with(key="accelerator-type")
47
+
48
+
49
+ # Test get_node_name
50
+ @patch("tpu_inference.tpu_info.get_tpu_metadata")
51
+ @patch.dict(os.environ, {"TPU_NAME": "env_tpu_name"})
52
+ def test_get_node_name_from_env(mock_get_tpu_metadata):
53
+ """Test get_node_name when TPU_NAME is set in environment."""
54
+ assert get_node_name() == "env_tpu_name"
55
+ mock_get_tpu_metadata.assert_not_called()
56
+
57
+
58
+ @patch.dict(os.environ, {}, clear=True)
59
+ @patch("tpu_inference.tpu_info.get_tpu_metadata",
60
+ return_value="metadata_tpu_name")
61
+ def test_get_node_name_from_metadata(mock_get_tpu_metadata):
62
+ """Test get_node_name when environment variable is not set."""
63
+ assert get_node_name() == "metadata_tpu_name"
64
+ mock_get_tpu_metadata.assert_called_once_with(key="instance-id")
65
+
66
+
67
+ # Test get_node_worker_id
68
+ @patch("tpu_inference.tpu_info.get_tpu_metadata")
69
+ @patch.dict(os.environ, {"TPU_WORKER_ID": "5"})
70
+ def test_get_node_worker_id_from_env(mock_get_tpu_metadata):
71
+ """Test get_node_worker_id when TPU_WORKER_ID is set in environment."""
72
+ assert get_node_worker_id() == 5
73
+ mock_get_tpu_metadata.assert_not_called()
74
+
75
+
76
+ @patch.dict(os.environ, {}, clear=True)
77
+ @patch("tpu_inference.tpu_info.get_tpu_metadata", return_value="10")
78
+ def test_get_node_worker_id_from_metadata(mock_get_tpu_metadata):
79
+ """Test get_node_worker_id when environment variable is not set."""
80
+ assert get_node_worker_id() == 10
81
+ mock_get_tpu_metadata.assert_called_once_with(key="agent-worker-number")
82
+
83
+
84
+ # Test get_num_cores_per_chip
85
+ @pytest.mark.parametrize(
86
+ "tpu_type, expected",
87
+ [
88
+ ("v5litepod-4", 1),
89
+ ("v6e-8", 1),
90
+ ("v4-8", 2),
91
+ ("v5p-16", 2),
92
+ ("unknown-type", 2) # Default case
93
+ ])
94
+ @patch("tpu_inference.tpu_info.get_tpu_type")
95
+ def test_get_num_cores_per_chip(mock_get_tpu_type, tpu_type, expected):
96
+ """Test get_num_cores_per_chip with different TPU types."""
97
+ mock_get_tpu_type.return_value = tpu_type
98
+ assert get_num_cores_per_chip() == expected
99
+
100
+
101
+ # Test get_num_chips
102
+ @patch("tpu_inference.tpu_info.glob.glob",
103
+ return_value=["/dev/accel0", "/dev/accel1"])
104
+ def test_get_num_chips_from_accel(mock_glob):
105
+ """Test get_num_chips when /dev/accel* files exist."""
106
+ assert get_num_chips() == 2
107
+
108
+
109
+ @patch("tpu_inference.tpu_info.glob.glob", return_value=[])
110
+ @patch("tpu_inference.tpu_info.os.listdir", return_value=["0", "1", "2"])
111
+ def test_get_num_chips_from_vfio(mock_listdir, mock_glob):
112
+ """Test get_num_chips when /dev/accel* files don't exist but /dev/vfio entries do."""
113
+ assert get_num_chips() == 3
114
+
115
+
116
+ @patch("tpu_inference.tpu_info.glob.glob", return_value=[])
117
+ @patch("tpu_inference.tpu_info.os.listdir", side_effect=FileNotFoundError)
118
+ def test_get_num_chips_not_found(mock_listdir, mock_glob, caplog):
119
+ """Test get_num_chips when neither files nor directory are found."""
120
+ assert get_num_chips() == 0
tests/test_utils.py ADDED
@@ -0,0 +1,236 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ import os
3
+ from unittest.mock import MagicMock, patch
4
+
5
+ import jax.numpy as jnp
6
+ import pytest
7
+
8
+ # Import the functions to be tested
9
+ from tpu_inference.utils import (GBYTES, enable_megacore,
10
+ get_jax_dtype_from_str_dtype, get_megacore,
11
+ get_padded_head_dim, hbm_usage_bytes,
12
+ hbm_usage_gb, quantize_kv)
13
+
14
+
15
+ def test_enable_and_get_megacore():
16
+ """Tests the enable_megacore and get_megacore functions."""
17
+ assert not get_megacore()
18
+ enable_megacore()
19
+ assert get_megacore()
20
+
21
+
22
+ @patch.dict(os.environ, {"TPU_MULTIHOST_BACKEND": "ray"})
23
+ def test_hbm_usage_bytes_ray_backend():
24
+ """Tests hbm_usage_bytes when TPU_MULTIHOST_BACKEND is ray."""
25
+ mock_device1 = MagicMock()
26
+ mock_device1.memory_stats.return_value = {
27
+ "bytes_in_use": 100 * GBYTES,
28
+ "bytes_limit": 128 * GBYTES
29
+ }
30
+ mock_device2 = MagicMock()
31
+ mock_device2.memory_stats.side_effect = Exception("Memory stats failed")
32
+
33
+ devices = [mock_device1, mock_device2]
34
+ usage = hbm_usage_bytes(devices)
35
+
36
+ expected_usage = [(100 * GBYTES, 128 * GBYTES),
37
+ (100 * GBYTES, 128 * GBYTES)]
38
+ assert usage == expected_usage
39
+
40
+
41
+ @patch("vllm.envs.VLLM_TPU_USING_PATHWAYS", False)
42
+ def test_hbm_usage_bytes_pathways_disabled():
43
+ """Tests hbm_usage_bytes when VLLM_TPU_USING_PATHWAYS is False."""
44
+ mock_device1 = MagicMock()
45
+ mock_device1.memory_stats.return_value = {
46
+ "bytes_in_use": 100 * GBYTES,
47
+ "bytes_limit": 128 * GBYTES
48
+ }
49
+ mock_device2 = MagicMock()
50
+ mock_device2.memory_stats.return_value = {
51
+ "bytes_in_use": 50 * GBYTES,
52
+ "bytes_limit": 128 * GBYTES
53
+ }
54
+
55
+ devices = [mock_device1, mock_device2]
56
+ usage = hbm_usage_bytes(devices)
57
+
58
+ expected_usage = [(100 * GBYTES, 128 * GBYTES),
59
+ (50 * GBYTES, 128 * GBYTES)]
60
+ assert usage == expected_usage
61
+
62
+
63
+ @patch("vllm.envs.VLLM_TPU_USING_PATHWAYS", True)
64
+ @patch("jax.live_arrays")
65
+ @patch("jax.devices")
66
+ def test_hbm_usage_bytes_pathways_enabled(mock_devices, mock_live_arrays):
67
+ """Tests hbm_usage_bytes when VLLM_TPU_USING_PATHWAYS is True."""
68
+ # Mock TPU v5p devices
69
+ mock_jax_device = MagicMock()
70
+ mock_jax_device.device_kind = "TPU v5p"
71
+ mock_devices.return_value = [mock_jax_device]
72
+
73
+ # Create mock devices
74
+ mock_device1 = MagicMock()
75
+ mock_device2 = MagicMock()
76
+ devices = [mock_device1, mock_device2]
77
+
78
+ # Create mock addressable shards with data property
79
+ mock_data1_dev1 = MagicMock()
80
+ mock_data1_dev1.device = mock_device1
81
+ mock_data1_dev1.nbytes = 2000 # 2000 bytes on device1
82
+
83
+ mock_data1_dev2 = MagicMock()
84
+ mock_data1_dev2.device = mock_device2
85
+ mock_data1_dev2.nbytes = 2000 # 2000 bytes on device2
86
+
87
+ mock_data2_dev1 = MagicMock()
88
+ mock_data2_dev1.device = mock_device1
89
+ mock_data2_dev1.nbytes = 1000 # 1000 bytes on device1
90
+
91
+ mock_shard1_dev1 = MagicMock()
92
+ mock_shard1_dev1.data = mock_data1_dev1
93
+
94
+ mock_shard1_dev2 = MagicMock()
95
+ mock_shard1_dev2.data = mock_data1_dev2
96
+
97
+ mock_shard2_dev1 = MagicMock()
98
+ mock_shard2_dev1.data = mock_data2_dev1
99
+
100
+ # Create mock arrays with addressable_shards
101
+ mock_array1 = MagicMock()
102
+ mock_array1.addressable_shards = [mock_shard1_dev1, mock_shard1_dev2]
103
+
104
+ mock_array2 = MagicMock()
105
+ mock_array2.addressable_shards = [mock_shard2_dev1]
106
+
107
+ mock_live_arrays.return_value = [mock_array1, mock_array2]
108
+
109
+ usage = hbm_usage_bytes(devices)
110
+
111
+ # Expected calculations:
112
+ # Array1: 2000 bytes on device1, 2000 bytes on device2
113
+ # Array2: 1000 bytes on device1
114
+ # Device1 total: 2000 + 1000 = 3000 bytes
115
+ # Device2 total: 2000 + 0 = 2000 bytes
116
+ # hbm_limit = 95 * GBYTES for TPU v5p
117
+ expected_usage = [(3000, 95 * GBYTES), (2000, 95 * GBYTES)]
118
+ assert usage == expected_usage
119
+
120
+
121
+ @patch("vllm.envs.VLLM_TPU_USING_PATHWAYS", False)
122
+ def test_hbm_usage_gb_pathways_disabled():
123
+ """Tests hbm_usage_gb when VLLM_TPU_USING_PATHWAYS is False."""
124
+ mock_device1 = MagicMock()
125
+ mock_device1.memory_stats.return_value = {
126
+ "bytes_in_use": 100 * GBYTES,
127
+ "bytes_limit": 128 * GBYTES
128
+ }
129
+ mock_device2 = MagicMock()
130
+ mock_device2.memory_stats.return_value = {
131
+ "bytes_in_use": 50.5 * GBYTES,
132
+ "bytes_limit": 128.0 * GBYTES
133
+ }
134
+
135
+ devices = [mock_device1, mock_device2]
136
+ usage = hbm_usage_gb(devices)
137
+
138
+ expected_usage = [(100.0, 128.0), (50.5, 128.0)]
139
+ assert usage == expected_usage
140
+
141
+
142
+ @patch("vllm.envs.VLLM_TPU_USING_PATHWAYS", True)
143
+ @patch("jax.live_arrays")
144
+ @patch("jax.devices")
145
+ def test_hbm_usage_bytes_pathways_no_arrays(mock_devices, mock_live_arrays):
146
+ """Tests hbm_usage_bytes when VLLM_TPU_USING_PATHWAYS is True but no live arrays."""
147
+ # Mock TPU v6e devices
148
+ mock_jax_device = MagicMock()
149
+ mock_jax_device.device_kind = "TPU v6e"
150
+ mock_devices.return_value = [mock_jax_device]
151
+
152
+ mock_device1 = MagicMock()
153
+ mock_device2 = MagicMock()
154
+ devices = [mock_device1, mock_device2]
155
+
156
+ # No live arrays
157
+ mock_live_arrays.return_value = []
158
+
159
+ usage = hbm_usage_bytes(devices)
160
+
161
+ # No arrays means no memory usage, defaultdict returns 0 for missing keys
162
+ # HBM limit for TPU v6e is 32 GB
163
+ expected_usage = [(0, 32 * GBYTES), (0, 32 * GBYTES)]
164
+ assert usage == expected_usage
165
+
166
+
167
+ @pytest.mark.parametrize(
168
+ "head_dim, expected_padded_head_dim",
169
+ [
170
+ (1, 128),
171
+ (64, 64),
172
+ (127, 128),
173
+ (128, 128),
174
+ (129, 256),
175
+ (255, 256),
176
+ (256, 256),
177
+ (0, 0), # Although head_dim is usually positive, testing boundary
178
+ ],
179
+ )
180
+ def test_get_padded_head_dim(head_dim, expected_padded_head_dim):
181
+ """Tests the get_padded_head_dim function."""
182
+ assert get_padded_head_dim(head_dim) == expected_padded_head_dim
183
+
184
+
185
+ def test_quantize_kv_float8_e4m3fn():
186
+ """Tests the quantize_kv function with float8_e4m3fn dtype."""
187
+ key = jnp.array([-1.0, 0.5, 1.0, 1.5])
188
+ value = jnp.array([2.0, 0.0, -2.0, -3.0])
189
+ kv_cache_quantized_dtype = jnp.float8_e4m3fn
190
+ k_scale = 0.1
191
+ v_scale = 0.2
192
+
193
+ quantized_key, quantized_value = quantize_kv(key, value,
194
+ kv_cache_quantized_dtype,
195
+ k_scale, v_scale)
196
+
197
+ # Expected key: key / k_scale -> clip -> astype
198
+ # [-10., 5., 10., 15.] are within float8_e4m3fn range
199
+ expected_key = jnp.array([-10.0, 5.0, 10.0, 15.0], dtype=jnp.float8_e4m3fn)
200
+
201
+ # Expected value: value / v_scale -> clip -> astype
202
+ # [10., 0., -10., -15.] are within float8_e4m3fn range
203
+ expected_value = jnp.array([10.0, 0.0, -10.0, -15.0],
204
+ dtype=jnp.float8_e4m3fn)
205
+
206
+ assert jnp.array_equal(quantized_key, expected_key)
207
+ assert jnp.array_equal(quantized_value, expected_value)
208
+
209
+ # Test clipping
210
+ dtype_info = jnp.finfo(kv_cache_quantized_dtype)
211
+ minval, maxval = float(dtype_info.min), float(dtype_info.max)
212
+
213
+ # Values that will be outside the range after scaling
214
+ key_clip = jnp.array([minval * k_scale * 2, maxval * k_scale * 2])
215
+ value_clip = jnp.array([maxval * v_scale * 2, minval * v_scale * 2])
216
+ quantized_key_clip, quantized_value_clip = quantize_kv(
217
+ key_clip, value_clip, kv_cache_quantized_dtype, k_scale, v_scale)
218
+
219
+ # Values should be clipped to the min/max of the float8 dtype
220
+ expected_key_clip = jnp.array([minval, maxval], dtype=jnp.float8_e4m3fn)
221
+ expected_value_clip = jnp.array([maxval, minval], dtype=jnp.float8_e4m3fn)
222
+
223
+ assert jnp.array_equal(quantized_key_clip, expected_key_clip)
224
+ assert jnp.array_equal(quantized_value_clip, expected_value_clip)
225
+
226
+
227
+ def test_get_jax_dtype_from_str_dtype():
228
+ """
229
+ Test the get_jax_dtype_from_str_dtype function
230
+ """
231
+ assert get_jax_dtype_from_str_dtype("int8") == jnp.int8
232
+ assert get_jax_dtype_from_str_dtype("bfloat16") == jnp.bfloat16
233
+ assert get_jax_dtype_from_str_dtype("fp8") == jnp.float8_e4m3fn
234
+ assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.float8_e4m3
235
+ assert get_jax_dtype_from_str_dtype("fp8_e5m2") == jnp.float8_e5m2
236
+ assert get_jax_dtype_from_str_dtype("auto") is None
tpu_inference/__init__.py ADDED
@@ -0,0 +1,34 @@
1
+ import os
2
+
3
+ # The environment variables override should be imported before any other
4
+ # modules to ensure that the environment variables are set before any
5
+ # other modules are imported.
6
+ import tpu_inference.env_override # noqa: F401
7
+ from tpu_inference import tpu_info as ti
8
+ from tpu_inference.logger import init_logger
9
+
10
+ logger = init_logger(__name__)
11
+
12
+ if "proxy" in os.environ.get('JAX_PLATFORMS', '').lower():
13
+ logger.info("Running vLLM on TPU via Pathways proxy.")
14
+ # Must run pathwaysutils.initialize() before any JAX operations
15
+ try:
16
+ import pathwaysutils
17
+ pathwaysutils.initialize()
18
+ logger.info("Module pathwaysutils is imported.")
19
+ except Exception as e:
20
+ logger.error(
21
+ f"Error occurred while importing pathwaysutils or logging TPU info: {e}"
22
+ )
23
+ else:
24
+ # Either running on TPU or CPU
25
+ try:
26
+ logger.info(f"TPU info: node_name={ti.get_node_name()} | "
27
+ f"tpu_type={ti.get_tpu_type()} | "
28
+ f"worker_id={ti.get_node_worker_id()} | "
29
+ f"num_chips={ti.get_num_chips()} | "
30
+ f"num_cores_per_chip={ti.get_num_cores_per_chip()}")
31
+ except Exception as e:
32
+ logger.error(
33
+ f"Error occurred while logging TPU info: {e}. Are you running on CPU?"
34
+ )
File without changes