tpu-inference 0.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this release of tpu-inference has been flagged as potentially problematic.

Files changed (168)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/collectives/__init__.py +0 -0
  53. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  54. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  55. tpu_inference/kernels/collectives/util.py +47 -0
  56. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  57. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  58. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  59. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  60. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  61. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  62. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  66. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  71. tpu_inference/layers/__init__.py +0 -0
  72. tpu_inference/layers/common/__init__.py +0 -0
  73. tpu_inference/layers/common/attention_metadata.py +34 -0
  74. tpu_inference/layers/jax/__init__.py +0 -0
  75. tpu_inference/layers/jax/attention/__init__.py +0 -0
  76. tpu_inference/layers/jax/attention/attention.py +254 -0
  77. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  78. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  79. tpu_inference/layers/jax/attention_interface.py +356 -0
  80. tpu_inference/layers/jax/base.py +151 -0
  81. tpu_inference/layers/jax/binary_search.py +295 -0
  82. tpu_inference/layers/jax/constants.py +88 -0
  83. tpu_inference/layers/jax/layers.py +301 -0
  84. tpu_inference/layers/jax/misc.py +16 -0
  85. tpu_inference/layers/jax/moe/__init__.py +0 -0
  86. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  87. tpu_inference/layers/jax/moe/moe.py +209 -0
  88. tpu_inference/layers/jax/rope.py +172 -0
  89. tpu_inference/layers/jax/rope_interface.py +214 -0
  90. tpu_inference/layers/jax/sample/__init__.py +0 -0
  91. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  92. tpu_inference/layers/jax/sample/sampling.py +95 -0
  93. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  94. tpu_inference/layers/jax/sharding.py +406 -0
  95. tpu_inference/layers/jax/transformer_block.py +76 -0
  96. tpu_inference/layers/vllm/__init__.py +0 -0
  97. tpu_inference/layers/vllm/attention.py +184 -0
  98. tpu_inference/layers/vllm/fused_moe.py +399 -0
  99. tpu_inference/layers/vllm/linear_common.py +186 -0
  100. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  101. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  102. tpu_inference/layers/vllm/quantization/common.py +105 -0
  103. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  104. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  105. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  106. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  108. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  109. tpu_inference/layers/vllm/sharding.py +151 -0
  110. tpu_inference/logger.py +10 -0
  111. tpu_inference/lora/__init__.py +0 -0
  112. tpu_inference/lora/torch_lora_ops.py +103 -0
  113. tpu_inference/lora/torch_punica_tpu.py +308 -0
  114. tpu_inference/mock/__init__.py +0 -0
  115. tpu_inference/mock/vllm_config_utils.py +28 -0
  116. tpu_inference/mock/vllm_envs.py +1233 -0
  117. tpu_inference/mock/vllm_logger.py +212 -0
  118. tpu_inference/mock/vllm_logging_utils.py +15 -0
  119. tpu_inference/models/__init__.py +0 -0
  120. tpu_inference/models/common/__init__.py +0 -0
  121. tpu_inference/models/common/model_loader.py +433 -0
  122. tpu_inference/models/jax/__init__.py +0 -0
  123. tpu_inference/models/jax/deepseek_v3.py +868 -0
  124. tpu_inference/models/jax/llama3.py +366 -0
  125. tpu_inference/models/jax/llama4.py +473 -0
  126. tpu_inference/models/jax/llama_eagle3.py +333 -0
  127. tpu_inference/models/jax/phi3.py +376 -0
  128. tpu_inference/models/jax/qwen2.py +375 -0
  129. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  130. tpu_inference/models/jax/qwen3.py +302 -0
  131. tpu_inference/models/jax/utils/__init__.py +0 -0
  132. tpu_inference/models/jax/utils/file_utils.py +96 -0
  133. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  134. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  136. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  137. tpu_inference/models/vllm/__init__.py +0 -0
  138. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  139. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  140. tpu_inference/platforms/__init__.py +2 -0
  141. tpu_inference/platforms/tpu_jax.py +257 -0
  142. tpu_inference/runner/__init__.py +0 -0
  143. tpu_inference/runner/block_table_jax.py +122 -0
  144. tpu_inference/runner/compilation_manager.py +672 -0
  145. tpu_inference/runner/input_batch_jax.py +435 -0
  146. tpu_inference/runner/kv_cache.py +119 -0
  147. tpu_inference/runner/kv_cache_manager.py +460 -0
  148. tpu_inference/runner/lora_utils.py +92 -0
  149. tpu_inference/runner/multimodal_manager.py +208 -0
  150. tpu_inference/runner/persistent_batch_manager.py +244 -0
  151. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  152. tpu_inference/runner/structured_decoding_manager.py +89 -0
  153. tpu_inference/runner/tpu_jax_runner.py +771 -0
  154. tpu_inference/runner/utils.py +426 -0
  155. tpu_inference/spec_decode/__init__.py +0 -0
  156. tpu_inference/spec_decode/jax/__init__.py +0 -0
  157. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  158. tpu_inference/tpu_info.py +77 -0
  159. tpu_inference/utils.py +294 -0
  160. tpu_inference/worker/__init__.py +0 -0
  161. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  162. tpu_inference/worker/base.py +100 -0
  163. tpu_inference/worker/tpu_worker_jax.py +321 -0
  164. tpu_inference-0.11.1.dist-info/METADATA +101 -0
  165. tpu_inference-0.11.1.dist-info/RECORD +168 -0
  166. tpu_inference-0.11.1.dist-info/WHEEL +5 -0
  167. tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
  168. tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tests/test_tpu_info.py ADDED
@@ -0,0 +1,120 @@
+ import os
+ from unittest.mock import MagicMock, patch
+
+ import pytest
+ import requests
+
+ from tpu_inference.tpu_info import (get_node_name, get_node_worker_id,
+                                     get_num_chips, get_num_cores_per_chip,
+                                     get_tpu_metadata, get_tpu_type)
+
+
+ # Mock requests.get for get_tpu_metadata tests
+ @patch("tpu_inference.tpu_info.requests.get")
+ def test_get_tpu_metadata_success(mock_get):
+     """Test get_tpu_metadata when the request is successful."""
+     mock_response = MagicMock()
+     mock_response.status_code = 200
+     mock_response.text = "test_metadata_value"
+     mock_get.return_value = mock_response
+     assert get_tpu_metadata(key="test-key") == "test_metadata_value"
+
+
+ @patch("tpu_inference.tpu_info.requests.get")
+ def test_get_tpu_metadata_request_error(mock_get):
+     """Test get_tpu_metadata when a RequestException is raised."""
+     mock_get.side_effect = requests.RequestException("Test RequestException")
+     assert get_tpu_metadata(key="test-key") is None
+
+
+ # Test get_tpu_type
+ @patch("tpu_inference.tpu_info.get_tpu_metadata")
+ @patch.dict(os.environ, {"TPU_ACCELERATOR_TYPE": "env_tpu_type"})
+ def test_get_tpu_type_from_env(mock_get_tpu_metadata):
+     """Test get_tpu_type when TPU_ACCELERATOR_TYPE is set in environment."""
+     # The function should return the env var value and not call get_tpu_metadata
+     assert get_tpu_type() == "env_tpu_type"
+     mock_get_tpu_metadata.assert_not_called()
+
+
+ @patch.dict(os.environ, {}, clear=True)
+ @patch("tpu_inference.tpu_info.get_tpu_metadata",
+        return_value="metadata_tpu_type")
+ def test_get_tpu_type_from_metadata(mock_get_tpu_metadata):
+     """Test get_tpu_type when environment variable is not set."""
+     assert get_tpu_type() == "metadata_tpu_type"
+     mock_get_tpu_metadata.assert_called_once_with(key="accelerator-type")
+
+
+ # Test get_node_name
+ @patch("tpu_inference.tpu_info.get_tpu_metadata")
+ @patch.dict(os.environ, {"TPU_NAME": "env_tpu_name"})
+ def test_get_node_name_from_env(mock_get_tpu_metadata):
+     """Test get_node_name when TPU_NAME is set in environment."""
+     assert get_node_name() == "env_tpu_name"
+     mock_get_tpu_metadata.assert_not_called()
+
+
+ @patch.dict(os.environ, {}, clear=True)
+ @patch("tpu_inference.tpu_info.get_tpu_metadata",
+        return_value="metadata_tpu_name")
+ def test_get_node_name_from_metadata(mock_get_tpu_metadata):
+     """Test get_node_name when environment variable is not set."""
+     assert get_node_name() == "metadata_tpu_name"
+     mock_get_tpu_metadata.assert_called_once_with(key="instance-id")
+
+
+ # Test get_node_worker_id
+ @patch("tpu_inference.tpu_info.get_tpu_metadata")
+ @patch.dict(os.environ, {"TPU_WORKER_ID": "5"})
+ def test_get_node_worker_id_from_env(mock_get_tpu_metadata):
+     """Test get_node_worker_id when TPU_WORKER_ID is set in environment."""
+     assert get_node_worker_id() == 5
+     mock_get_tpu_metadata.assert_not_called()
+
+
+ @patch.dict(os.environ, {}, clear=True)
+ @patch("tpu_inference.tpu_info.get_tpu_metadata", return_value="10")
+ def test_get_node_worker_id_from_metadata(mock_get_tpu_metadata):
+     """Test get_node_worker_id when environment variable is not set."""
+     assert get_node_worker_id() == 10
+     mock_get_tpu_metadata.assert_called_once_with(key="agent-worker-number")
+
+
+ # Test get_num_cores_per_chip
+ @pytest.mark.parametrize(
+     "tpu_type, expected",
+     [
+         ("v5litepod-4", 1),
+         ("v6e-8", 1),
+         ("v4-8", 2),
+         ("v5p-16", 2),
+         ("unknown-type", 2),  # Default case
+     ])
+ @patch("tpu_inference.tpu_info.get_tpu_type")
+ def test_get_num_cores_per_chip(mock_get_tpu_type, tpu_type, expected):
+     """Test get_num_cores_per_chip with different TPU types."""
+     mock_get_tpu_type.return_value = tpu_type
+     assert get_num_cores_per_chip() == expected
+
+
+ # Test get_num_chips
+ @patch("tpu_inference.tpu_info.glob.glob",
+        return_value=["/dev/accel0", "/dev/accel1"])
+ def test_get_num_chips_from_accel(mock_glob):
+     """Test get_num_chips when /dev/accel* files exist."""
+     assert get_num_chips() == 2
+
+
+ @patch("tpu_inference.tpu_info.glob.glob", return_value=[])
+ @patch("tpu_inference.tpu_info.os.listdir", return_value=["0", "1", "2"])
+ def test_get_num_chips_from_vfio(mock_listdir, mock_glob):
+     """Test get_num_chips when /dev/accel* files don't exist but /dev/vfio entries do."""
+     assert get_num_chips() == 3
+
+
+ @patch("tpu_inference.tpu_info.glob.glob", return_value=[])
+ @patch("tpu_inference.tpu_info.os.listdir", side_effect=FileNotFoundError)
+ def test_get_num_chips_not_found(mock_listdir, mock_glob, caplog):
+     """Test get_num_chips when neither files nor directory are found."""
+     assert get_num_chips() == 0
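For context, the fallback order these tests pin down (environment variable first, then the instance metadata server) could be implemented roughly as in the sketch below. This is an illustrative assumption about the shape of tpu_inference.tpu_info, not the packaged implementation; the metadata URL and headers are assumed GCE-style conventions.

import os

import requests

# Assumed metadata endpoint; the real module may use a different URL or headers.
_METADATA_URL = ("http://metadata.google.internal/computeMetadata/v1/"
                 "instance/attributes/")


def get_tpu_metadata(key: str):
    """Return the metadata value for `key`, or None on any request failure."""
    try:
        resp = requests.get(_METADATA_URL + key,
                            headers={"Metadata-Flavor": "Google"},
                            timeout=5)
        return resp.text if resp.status_code == 200 else None
    except requests.RequestException:
        return None


def get_tpu_type():
    """Prefer the TPU_ACCELERATOR_TYPE env var, else query the metadata server."""
    return os.environ.get("TPU_ACCELERATOR_TYPE") or get_tpu_metadata(
        key="accelerator-type")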
tests/test_utils.py ADDED
@@ -0,0 +1,218 @@
+ # SPDX-License-Identifier: Apache-2.0
+ import os
+ from unittest.mock import MagicMock, patch
+
+ import jax.numpy as jnp
+ import pytest
+
+ # Import the functions to be tested
+ from tpu_inference.utils import (GBYTES, enable_megacore,
+                                  get_jax_dtype_from_str_dtype, get_megacore,
+                                  get_padded_head_dim, hbm_usage_bytes,
+                                  hbm_usage_gb, quantize_kv)
+
+
+ def test_enable_and_get_megacore():
+     """Tests the enable_megacore and get_megacore functions."""
+     assert not get_megacore()
+     enable_megacore()
+     assert get_megacore()
+
+
+ @patch.dict(os.environ, {"TPU_MULTIHOST_BACKEND": "ray"})
+ def test_hbm_usage_bytes_ray_backend():
+     """Tests hbm_usage_bytes when TPU_MULTIHOST_BACKEND is ray."""
+     mock_device1 = MagicMock()
+     mock_device1.memory_stats.return_value = {
+         "bytes_in_use": 100 * GBYTES,
+         "bytes_limit": 128 * GBYTES
+     }
+     mock_device2 = MagicMock()
+     mock_device2.memory_stats.side_effect = Exception("Memory stats failed")
+
+     devices = [mock_device1, mock_device2]
+     usage = hbm_usage_bytes(devices)
+
+     expected_usage = [(100 * GBYTES, 128 * GBYTES),
+                       (100 * GBYTES, 128 * GBYTES)]
+     assert usage == expected_usage
+
+
+ @patch("vllm.envs.VLLM_TPU_USING_PATHWAYS", False)
+ def test_hbm_usage_bytes_pathways_disabled():
+     """Tests hbm_usage_bytes when VLLM_TPU_USING_PATHWAYS is False."""
+     mock_device1 = MagicMock()
+     mock_device1.memory_stats.return_value = {
+         "bytes_in_use": 100 * GBYTES,
+         "bytes_limit": 128 * GBYTES
+     }
+     mock_device2 = MagicMock()
+     mock_device2.memory_stats.return_value = {
+         "bytes_in_use": 50 * GBYTES,
+         "bytes_limit": 128 * GBYTES
+     }
+
+     devices = [mock_device1, mock_device2]
+     usage = hbm_usage_bytes(devices)
+
+     expected_usage = [(100 * GBYTES, 128 * GBYTES),
+                       (50 * GBYTES, 128 * GBYTES)]
+     assert usage == expected_usage
+
+
+ @patch("vllm.envs.VLLM_TPU_USING_PATHWAYS", True)
+ @patch("jax.live_arrays")
+ @patch("jax.devices")
+ def test_hbm_usage_bytes_pathways_enabled(mock_devices, mock_live_arrays):
+     """Tests hbm_usage_bytes when VLLM_TPU_USING_PATHWAYS is True."""
+     # Mock TPU v5p devices
+     mock_jax_device = MagicMock()
+     mock_jax_device.device_kind = "TPU v5p"
+     mock_devices.return_value = [mock_jax_device]
+
+     # Create mock devices
+     mock_device1 = MagicMock()
+     mock_device2 = MagicMock()
+     devices = [mock_device1, mock_device2]
+
+     # Create mock arrays with sharding
+     mock_array1 = MagicMock()
+     mock_array1.dtype.itemsize = 4  # float32
+     mock_array1.size = 1000  # 1000 elements
+     mock_array1.sharding.device_set = {mock_device1,
+                                        mock_device2}  # Sharded across 2 devices
+
+     mock_array2 = MagicMock()
+     mock_array2.dtype.itemsize = 2  # float16
+     mock_array2.size = 500  # 500 elements
+     mock_array2.sharding.device_set = {mock_device1}  # Only on device1
+
+     mock_live_arrays.return_value = [mock_array1, mock_array2]
+
+     usage = hbm_usage_bytes(devices)
+
+     # Expected calculations:
+     # Array1: 4 bytes * 1000 elements / 2 devices = 2000 bytes per device
+     # Array2: 2 bytes * 500 elements / 1 device = 1000 bytes on device1 only
+     # Device1: 2000 + 1000 = 3000 bytes
+     # Device2: 2000 + 0 = 2000 bytes
+     # hbm_limit for TPU v5p is 95 * GBYTES (hardcoded in the function)
+     expected_usage = [(3000, 95 * GBYTES), (2000, 95 * GBYTES)]
+     assert usage == expected_usage
+
+
+ @patch("vllm.envs.VLLM_TPU_USING_PATHWAYS", False)
+ def test_hbm_usage_gb_pathways_disabled():
+     """Tests hbm_usage_gb when VLLM_TPU_USING_PATHWAYS is False."""
+     mock_device1 = MagicMock()
+     mock_device1.memory_stats.return_value = {
+         "bytes_in_use": 100 * GBYTES,
+         "bytes_limit": 128 * GBYTES
+     }
+     mock_device2 = MagicMock()
+     mock_device2.memory_stats.return_value = {
+         "bytes_in_use": 50.5 * GBYTES,
+         "bytes_limit": 128.0 * GBYTES
+     }
+
+     devices = [mock_device1, mock_device2]
+     usage = hbm_usage_gb(devices)
+
+     expected_usage = [(100.0, 128.0), (50.5, 128.0)]
+     assert usage == expected_usage
+
+
+ @patch("vllm.envs.VLLM_TPU_USING_PATHWAYS", True)
+ @patch("jax.live_arrays")
+ @patch("jax.devices")
+ def test_hbm_usage_bytes_pathways_no_arrays(mock_devices, mock_live_arrays):
+     """Tests hbm_usage_bytes when VLLM_TPU_USING_PATHWAYS is True but there are no live arrays."""
+     # Mock TPU v6e devices
+     mock_jax_device = MagicMock()
+     mock_jax_device.device_kind = "TPU v6e"
+     mock_devices.return_value = [mock_jax_device]
+
+     mock_device1 = MagicMock()
+     mock_device2 = MagicMock()
+     devices = [mock_device1, mock_device2]
+
+     # No live arrays
+     mock_live_arrays.return_value = []
+
+     usage = hbm_usage_bytes(devices)
+
+     # No arrays means no memory usage; the hardcoded v6e limit is 32 * GBYTES
+     expected_usage = [(0, 32 * GBYTES), (0, 32 * GBYTES)]
+     assert usage == expected_usage
+
+
+ @pytest.mark.parametrize(
+     "head_dim, expected_padded_head_dim",
+     [
+         (1, 128),
+         (64, 128),
+         (127, 128),
+         (128, 128),
+         (129, 256),
+         (255, 256),
+         (256, 256),
+         (0, 0),  # Although head_dim is usually positive, testing boundary
+     ],
+ )
+ def test_get_padded_head_dim(head_dim, expected_padded_head_dim):
+     """Tests the get_padded_head_dim function."""
+     assert get_padded_head_dim(head_dim) == expected_padded_head_dim
+
+
+ def test_quantize_kv_float8_e4m3fn():
+     """Tests the quantize_kv function with float8_e4m3fn dtype."""
+     key = jnp.array([-1.0, 0.5, 1.0, 1.5])
+     value = jnp.array([2.0, 0.0, -2.0, -3.0])
+     kv_cache_quantized_dtype = jnp.float8_e4m3fn
+     k_scale = 0.1
+     v_scale = 0.2
+
+     quantized_key, quantized_value = quantize_kv(key, value,
+                                                  kv_cache_quantized_dtype,
+                                                  k_scale, v_scale)
+
+     # Expected key: key / k_scale -> clip -> astype
+     # [-10., 5., 10., 15.] are within float8_e4m3fn range
+     expected_key = jnp.array([-10.0, 5.0, 10.0, 15.0], dtype=jnp.float8_e4m3fn)
+
+     # Expected value: value / v_scale -> clip -> astype
+     # [10., 0., -10., -15.] are within float8_e4m3fn range
+     expected_value = jnp.array([10.0, 0.0, -10.0, -15.0],
+                                dtype=jnp.float8_e4m3fn)
+
+     assert jnp.array_equal(quantized_key, expected_key)
+     assert jnp.array_equal(quantized_value, expected_value)
+
+     # Test clipping
+     dtype_info = jnp.finfo(kv_cache_quantized_dtype)
+     minval, maxval = float(dtype_info.min), float(dtype_info.max)
+
+     # Values that will be outside the representable range after scaling
+     key_clip = jnp.array([minval * k_scale * 2, maxval * k_scale * 2])
+     value_clip = jnp.array([maxval * v_scale * 2, minval * v_scale * 2])
+     quantized_key_clip, quantized_value_clip = quantize_kv(
+         key_clip, value_clip, kv_cache_quantized_dtype, k_scale, v_scale)
+
+     # Values should be clipped to the min/max of the float8 dtype
+     expected_key_clip = jnp.array([minval, maxval], dtype=jnp.float8_e4m3fn)
+     expected_value_clip = jnp.array([maxval, minval], dtype=jnp.float8_e4m3fn)
+
+     assert jnp.array_equal(quantized_key_clip, expected_key_clip)
+     assert jnp.array_equal(quantized_value_clip, expected_value_clip)
+
+
+ def test_get_jax_dtype_from_str_dtype():
+     """Test the get_jax_dtype_from_str_dtype function."""
+     assert get_jax_dtype_from_str_dtype("int8") == jnp.int8
+     assert get_jax_dtype_from_str_dtype("bfloat16") == jnp.bfloat16
+     assert get_jax_dtype_from_str_dtype("fp8") == jnp.float8_e4m3fn
+     assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.float8_e4m3
+     assert get_jax_dtype_from_str_dtype("fp8_e5m2") == jnp.float8_e5m2
+     assert get_jax_dtype_from_str_dtype("auto") is None
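The quantize_kv assertions above encode a scale, clip, cast sequence. A minimal sketch consistent with these tests (an assumption about the shape of the function, not the shipped tpu_inference.utils implementation) is:

import jax.numpy as jnp


def quantize_kv(key, value, kv_cache_quantized_dtype, k_scale, v_scale):
    """Scale by the per-tensor scales, clip to the target dtype's finite range, cast."""
    info = jnp.finfo(kv_cache_quantized_dtype)
    quantized_key = jnp.clip(key / k_scale, info.min,
                             info.max).astype(kv_cache_quantized_dtype)
    quantized_value = jnp.clip(value / v_scale, info.min,
                               info.max).astype(kv_cache_quantized_dtype)
    return quantized_key, quantized_value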
tests/tpu_backend_test.py ADDED
@@ -0,0 +1,59 @@
+ import unittest
+ from unittest.mock import Mock, patch
+
+ from tpu_inference.backend import TPUBackend
+
+
+ class TPUBackendTest(unittest.TestCase):
+
+     @patch('tpu_inference.backend.TPUWorker')
+     def test_tpu_backend_initialization(self, mock_tpu_worker_class):
+         """Test that TPUBackend initializes the worker correctly."""
+         mock_host_interface = Mock()
+         mock_worker_kwargs = {'worker_arg': 'test_value'}
+
+         backend = TPUBackend(host_interface=mock_host_interface,
+                              **mock_worker_kwargs)
+
+         # Assert that the TPUWorker was instantiated with the correct arguments
+         mock_tpu_worker_class.assert_called_once_with(
+             host_interface=mock_host_interface, **mock_worker_kwargs)
+
+         # Assert that the worker attribute is the instance created by the mock class
+         self.assertEqual(backend.worker, mock_tpu_worker_class.return_value)
+
+     @patch('tpu_inference.backend.VllmSchedulerOutputAdapter')
+     @patch('tpu_inference.backend.TPUWorker')
+     def test_launch_tpu_batch(self, mock_tpu_worker_class, mock_adapter_class):
+         """Test that launch_tpu_batch delegates to the worker correctly."""
+         mock_worker_instance = mock_tpu_worker_class.return_value
+
+         backend = TPUBackend()
+         mock_batch = Mock()
+
+         backend.launch_tpu_batch(mock_batch)
+
+         # Assert that the adapter was created with the correct input
+         mock_adapter_class.assert_called_once_with(mock_batch)
+
+         # Assert that the worker's execute_model method was called with the mock adapter's return value
+         mock_worker_instance.execute_model.assert_called_once_with(
+             mock_adapter_class.return_value)
+
+     @patch('tpu_inference.backend.VllmLoRARequestAdapter')
+     @patch('tpu_inference.backend.TPUWorker')
+     def test_add_lora(self, mock_tpu_worker_class, mock_adapter_class):
+         """Test that add_lora delegates to the worker correctly."""
+         mock_worker_instance = mock_tpu_worker_class.return_value
+
+         backend = TPUBackend()
+         mock_lora_request = Mock()
+
+         backend.add_lora(mock_lora_request)
+
+         # Assert that the adapter was created with the correct input
+         mock_adapter_class.assert_called_once_with(mock_lora_request)
+
+         # Assert that the worker's add_lora method was called with the mock adapter's return value
+         mock_worker_instance.add_lora.assert_called_once_with(
+             mock_adapter_class.return_value)
tpu_inference/__init__.py ADDED
@@ -0,0 +1,30 @@
+ import os
+
+ from tpu_inference import tpu_info as ti
+ from tpu_inference.logger import init_logger
+
+ logger = init_logger(__name__)
+
+ if "proxy" in os.environ.get('JAX_PLATFORMS', '').lower():
+     logger.info("Running vLLM on TPU via Pathways proxy.")
+     # Must run pathwaysutils.initialize() before any JAX operations
+     try:
+         import pathwaysutils
+         pathwaysutils.initialize()
+         logger.info("Module pathwaysutils is imported.")
+     except Exception as e:
+         logger.error(
+             f"Error occurred while importing or initializing pathwaysutils: {e}"
+         )
+ else:
+     # Either running on TPU or CPU
+     try:
+         logger.info(f"TPU info: node_name={ti.get_node_name()} | "
+                     f"tpu_type={ti.get_tpu_type()} | "
+                     f"worker_id={ti.get_node_worker_id()} | "
+                     f"num_chips={ti.get_num_chips()} | "
+                     f"num_cores_per_chip={ti.get_num_cores_per_chip()}")
+     except Exception as e:
+         logger.error(
+             f"Error occurred while logging TPU info: {e}. Are you running on CPU?"
+         )
File without changes
tpu_inference/adapters/vllm_adapters.py ADDED
@@ -0,0 +1,42 @@
+ # SPDX-License-Identifier: Apache-2.0
+
+ from vllm.lora.request import LoRARequest
+ from vllm.v1.core.sched.output import SchedulerOutput
+ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
+ from vllm.v1.outputs import ModelRunnerOutput
+
+ from tpu_inference.di.abstracts import (AbstractKVCacheConfig,
+                                         AbstractKVCacheSpec,
+                                         AbstractLoRARequest,
+                                         AbstractModelRunnerOutput,
+                                         AbstractSchedulerOutput)
+
+
+ class VllmModelRunnerOutputAdapter(AbstractModelRunnerOutput):
+
+     def __init__(self, vllm_output: ModelRunnerOutput):
+         self.vllm_output = vllm_output
+
+
+ class VllmSchedulerOutputAdapter(AbstractSchedulerOutput):
+
+     def __init__(self, vllm_scheduler_output: SchedulerOutput):
+         self.vllm_scheduler_output = vllm_scheduler_output
+
+
+ class VllmLoRARequestAdapter(AbstractLoRARequest):
+
+     def __init__(self, vllm_lora_request: LoRARequest):
+         self.vllm_lora_request = vllm_lora_request
+
+
+ class VllmKVCacheConfigAdapter(AbstractKVCacheConfig):
+
+     def __init__(self, vllm_kv_cache_config: KVCacheConfig):
+         self.vllm_kv_cache_config = vllm_kv_cache_config
+
+
+ class VllmKVCacheSpecAdapter(AbstractKVCacheSpec):
+
+     def __init__(self, vllm_kv_cache_spec: KVCacheSpec):
+         self.vllm_kv_cache_spec = vllm_kv_cache_spec
tpu_inference/adapters/vllm_config_adapters.py ADDED
@@ -0,0 +1,134 @@
+ """
+ Adapters for wrapping concrete vLLM config objects in tpu_inference interfaces.
+ """
+ from typing import Any, Optional
+
+ import torch
+
+ from tpu_inference.interfaces.config_parts import (ICacheConfig,
+                                                    ICompilationConfig,
+                                                    IModelConfig,
+                                                    IParallelConfig,
+                                                    ISchedulerConfig)
+
+
+ class VllmCacheConfigAdapter(ICacheConfig):
+
+     def __init__(self, vllm_cache_config: Any):
+         self._vllm_cache_config = vllm_cache_config
+
+     @property
+     def block_size(self) -> Optional[int]:
+         return self._vllm_cache_config.block_size
+
+     @block_size.setter
+     def block_size(self, value: Optional[int]) -> None:
+         self._vllm_cache_config.block_size = value
+
+
+ class VllmCompilationConfigAdapter(ICompilationConfig):
+
+     def __init__(self, vllm_compilation_config: Any):
+         self._vllm_compilation_config = vllm_compilation_config
+
+     @property
+     def level(self) -> Any:
+         return self._vllm_compilation_config.level
+
+     @level.setter
+     def level(self, value: Any) -> None:
+         self._vllm_compilation_config.level = value
+
+     @property
+     def backend(self) -> str:
+         return self._vllm_compilation_config.backend
+
+     @backend.setter
+     def backend(self, value: str) -> None:
+         self._vllm_compilation_config.backend = value
+
+
+ class VllmModelConfigAdapter(IModelConfig):
+
+     def __init__(self, vllm_model_config: Any):
+         self._vllm_model_config = vllm_model_config
+
+     @property
+     def dtype(self) -> torch.dtype:
+         return self._vllm_model_config.dtype
+
+     @dtype.setter
+     def dtype(self, value: torch.dtype) -> None:
+         self._vllm_model_config.dtype = value
+
+     @property
+     def use_mla(self) -> bool:
+         return self._vllm_model_config.use_mla
+
+
+ class VllmParallelConfigAdapter(IParallelConfig):
+
+     def __init__(self, vllm_parallel_config: Any):
+         self._vllm_parallel_config = vllm_parallel_config
+
+     @property
+     def worker_cls(self) -> str:
+         return self._vllm_parallel_config.worker_cls
+
+     @worker_cls.setter
+     def worker_cls(self, value: str) -> None:
+         self._vllm_parallel_config.worker_cls = value
+
+
+ class VllmSchedulerConfigAdapter(ISchedulerConfig):
+
+     def __init__(self, vllm_scheduler_config: Any):
+         self._vllm_scheduler_config = vllm_scheduler_config
+
+     @property
+     def max_num_seqs(self) -> int:
+         return self._vllm_scheduler_config.max_num_seqs
+
+     @property
+     def is_multi_step(self) -> bool:
+         return self._vllm_scheduler_config.is_multi_step
+
+     @property
+     def is_multimodal_model(self) -> bool:
+         return self._vllm_scheduler_config.is_multimodal_model
+
+     @property
+     def disable_chunked_mm_input(self) -> bool:
+         return self._vllm_scheduler_config.disable_chunked_mm_input
+
+     @disable_chunked_mm_input.setter
+     def disable_chunked_mm_input(self, value: bool) -> None:
+         self._vllm_scheduler_config.disable_chunked_mm_input = value
+
+     @property
+     def enable_chunked_prefill(self) -> bool:
+         return self._vllm_scheduler_config.enable_chunked_prefill
+
+     @enable_chunked_prefill.setter
+     def enable_chunked_prefill(self, value: bool) -> None:
+         self._vllm_scheduler_config.enable_chunked_prefill = value
+
+     @property
+     def chunked_prefill_enabled(self) -> bool:
+         return self._vllm_scheduler_config.chunked_prefill_enabled
+
+     @chunked_prefill_enabled.setter
+     def chunked_prefill_enabled(self, value: bool) -> None:
+         self._vllm_scheduler_config.chunked_prefill_enabled = value
+
+     @property
+     def max_model_len(self) -> int:
+         return self._vllm_scheduler_config.max_model_len
+
+     @property
+     def max_num_batched_tokens(self) -> int:
+         return self._vllm_scheduler_config.max_num_batched_tokens
+
+     @max_num_batched_tokens.setter
+     def max_num_batched_tokens(self, value: int) -> None:
+         self._vllm_scheduler_config.max_num_batched_tokens = value
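These config adapters are thin pass-throughs: property reads and writes forward directly to the wrapped vLLM config object, so mutating the adapter mutates the original config. A hypothetical usage sketch follows, with a SimpleNamespace standing in for vLLM's real CacheConfig object:

from types import SimpleNamespace

from tpu_inference.adapters.vllm_config_adapters import VllmCacheConfigAdapter

vllm_cache_config = SimpleNamespace(block_size=16)  # stand-in for vllm's CacheConfig
adapter = VllmCacheConfigAdapter(vllm_cache_config)

assert adapter.block_size == 16
adapter.block_size = 32  # the setter writes through to the wrapped object
assert vllm_cache_config.block_size == 32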
tpu_inference/backend.py ADDED
@@ -0,0 +1,69 @@
+ """
+ Copyright 2025 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+ from typing import Optional
+
+ from tpu_inference.adapters.vllm_adapters import (VllmLoRARequestAdapter,
+                                                   VllmSchedulerOutputAdapter)
+ from tpu_inference.di.interfaces import BackendInterface, HostInterface
+ from tpu_inference.worker.base import AbstractTpuWorker
+ from tpu_inference.worker.tpu_worker_jax import TPUWorker
+
+
+ class TPUBackend(BackendInterface):
+     """
+     The main entry point for the host system to interact with the TPU backend.
+
+     This class implements the BackendInterface. It is responsible for creating
+     and managing the concrete TPU worker instance and delegating calls to it.
+     """
+
+     def __init__(self,
+                  host_interface: Optional[HostInterface] = None,
+                  **worker_kwargs):
+         """
+         Initializes the TPUBackend.
+
+         Args:
+             host_interface: An optional object that implements the HostInterface,
+                 providing a way for the backend to communicate with the host.
+             **worker_kwargs: Additional keyword arguments to be passed to the
+                 worker's constructor.
+         """
+         self.worker: AbstractTpuWorker = TPUWorker(
+             host_interface=host_interface, **worker_kwargs)
+
+     def launch_tpu_batch(self, batch_to_launch):
+         """
+         Launches a batch of requests on the TPU worker and returns the result.
+
+         Args:
+             batch_to_launch: The batch of requests to be processed.
+
+         Returns:
+             The result of the model execution.
+         """
+         adapted_batch = VllmSchedulerOutputAdapter(batch_to_launch)
+         return self.worker.execute_model(adapted_batch)
+
+     def add_lora(self, lora_request):
+         """
+         Adds a LoRA adapter to the worker.
+
+         Args:
+             lora_request: The LoRA request to be processed.
+         """
+         adapted_lora_request = VllmLoRARequestAdapter(lora_request)
+         return self.worker.add_lora(adapted_lora_request)
File without changes