tpu-inference 0.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_adapters.py +83 -0
- tests/core/test_core_tpu.py +523 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/test_lora.py +123 -0
- tests/test_base.py +201 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +218 -0
- tests/tpu_backend_test.py +59 -0
- tpu_inference/__init__.py +30 -0
- tpu_inference/adapters/__init__.py +0 -0
- tpu_inference/adapters/vllm_adapters.py +42 -0
- tpu_inference/adapters/vllm_config_adapters.py +134 -0
- tpu_inference/backend.py +69 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/adapters.py +153 -0
- tpu_inference/core/core_tpu.py +776 -0
- tpu_inference/core/disagg_executor.py +117 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/di/__init__.py +0 -0
- tpu_inference/di/abstracts.py +28 -0
- tpu_inference/di/host.py +76 -0
- tpu_inference/di/interfaces.py +51 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/tpu_connector.py +699 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +346 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/interfaces/__init__.py +0 -0
- tpu_inference/interfaces/cache.py +31 -0
- tpu_inference/interfaces/config.py +47 -0
- tpu_inference/interfaces/config_parts.py +117 -0
- tpu_inference/interfaces/engine.py +51 -0
- tpu_inference/interfaces/outputs.py +22 -0
- tpu_inference/interfaces/params.py +21 -0
- tpu_inference/interfaces/platform.py +74 -0
- tpu_inference/interfaces/request.py +39 -0
- tpu_inference/interfaces/scheduler.py +31 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +254 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/attention_interface.py +356 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/binary_search.py +295 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +172 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +95 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
- tpu_inference/layers/jax/sharding.py +406 -0
- tpu_inference/layers/jax/transformer_block.py +76 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +184 -0
- tpu_inference/layers/vllm/fused_moe.py +399 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +34 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
- tpu_inference/layers/vllm/sharding.py +151 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +308 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1233 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +433 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/llama3.py +366 -0
- tpu_inference/models/jax/llama4.py +473 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +976 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
- tpu_inference/models/jax/utils/weight_utils.py +510 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_jax.py +257 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table_jax.py +122 -0
- tpu_inference/runner/compilation_manager.py +672 -0
- tpu_inference/runner/input_batch_jax.py +435 -0
- tpu_inference/runner/kv_cache.py +119 -0
- tpu_inference/runner/kv_cache_manager.py +460 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +208 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +250 -0
- tpu_inference/runner/structured_decoding_manager.py +89 -0
- tpu_inference/runner/tpu_jax_runner.py +771 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +334 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +294 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/_temporary_vllm_compat.py +129 -0
- tpu_inference/worker/base.py +100 -0
- tpu_inference/worker/tpu_worker_jax.py +321 -0
- tpu_inference-0.11.1.dist-info/METADATA +101 -0
- tpu_inference-0.11.1.dist-info/RECORD +168 -0
- tpu_inference-0.11.1.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tests/test_tpu_info.py
ADDED
@@ -0,0 +1,120 @@
+import os
+from unittest.mock import MagicMock, patch
+
+import pytest
+import requests
+
+from tpu_inference.tpu_info import (get_node_name, get_node_worker_id,
+                                    get_num_chips, get_num_cores_per_chip,
+                                    get_tpu_metadata, get_tpu_type)
+
+
+# Mock requests.get for get_tpu_metadata tests
+@patch("tpu_inference.tpu_info.requests.get")
+def test_get_tpu_metadata_success(mock_get):
+    """Test get_tpu_metadata when the request is successful."""
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.text = "test_metadata_value"
+    mock_get.return_value = mock_response
+    assert get_tpu_metadata(key="test-key") == "test_metadata_value"
+
+
+@patch("tpu_inference.tpu_info.requests.get")
+def test_get_tpu_metadata_request_error(mock_get):
+    """Test get_tpu_metadata when a RequestException is raised."""
+    mock_get.side_effect = requests.RequestException("Test RequestException")
+    assert get_tpu_metadata(key="test-key") is None
+
+
+# Test get_tpu_type
+@patch("tpu_inference.tpu_info.get_tpu_metadata")
+@patch.dict(os.environ, {"TPU_ACCELERATOR_TYPE": "env_tpu_type"})
+def test_get_tpu_type_from_env(mock_get_tpu_metadata):
+    """Test get_tpu_type when TPU_ACCELERATOR_TYPE is set in environment."""
+    # The function should return the env var value and not call get_tpu_metadata
+    assert get_tpu_type() == "env_tpu_type"
+    mock_get_tpu_metadata.assert_not_called()
+
+
+@patch.dict(os.environ, {}, clear=True)
+@patch("tpu_inference.tpu_info.get_tpu_metadata",
+       return_value="metadata_tpu_type")
+def test_get_tpu_type_from_metadata(mock_get_tpu_metadata):
+    """Test get_tpu_type when environment variable is not set."""
+    assert get_tpu_type() == "metadata_tpu_type"
+    mock_get_tpu_metadata.assert_called_once_with(key="accelerator-type")
+
+
+# Test get_node_name
+@patch("tpu_inference.tpu_info.get_tpu_metadata")
+@patch.dict(os.environ, {"TPU_NAME": "env_tpu_name"})
+def test_get_node_name_from_env(mock_get_tpu_metadata):
+    """Test get_node_name when TPU_NAME is set in environment."""
+    assert get_node_name() == "env_tpu_name"
+    mock_get_tpu_metadata.assert_not_called()
+
+
+@patch.dict(os.environ, {}, clear=True)
+@patch("tpu_inference.tpu_info.get_tpu_metadata",
+       return_value="metadata_tpu_name")
+def test_get_node_name_from_metadata(mock_get_tpu_metadata):
+    """Test get_node_name when environment variable is not set."""
+    assert get_node_name() == "metadata_tpu_name"
+    mock_get_tpu_metadata.assert_called_once_with(key="instance-id")
+
+
+# Test get_node_worker_id
+@patch("tpu_inference.tpu_info.get_tpu_metadata")
+@patch.dict(os.environ, {"TPU_WORKER_ID": "5"})
+def test_get_node_worker_id_from_env(mock_get_tpu_metadata):
+    """Test get_node_worker_id when TPU_WORKER_ID is set in environment."""
+    assert get_node_worker_id() == 5
+    mock_get_tpu_metadata.assert_not_called()
+
+
+@patch.dict(os.environ, {}, clear=True)
+@patch("tpu_inference.tpu_info.get_tpu_metadata", return_value="10")
+def test_get_node_worker_id_from_metadata(mock_get_tpu_metadata):
+    """Test get_node_worker_id when environment variable is not set."""
+    assert get_node_worker_id() == 10
+    mock_get_tpu_metadata.assert_called_once_with(key="agent-worker-number")
+
+
+# Test get_num_cores_per_chip
+@pytest.mark.parametrize(
+    "tpu_type, expected",
+    [
+        ("v5litepod-4", 1),
+        ("v6e-8", 1),
+        ("v4-8", 2),
+        ("v5p-16", 2),
+        ("unknown-type", 2)  # Default case
+    ])
+@patch("tpu_inference.tpu_info.get_tpu_type")
+def test_get_num_cores_per_chip(mock_get_tpu_type, tpu_type, expected):
+    """Test get_num_cores_per_chip with different TPU types."""
+    mock_get_tpu_type.return_value = tpu_type
+    assert get_num_cores_per_chip() == expected
+
+
+# Test get_num_chips
+@patch("tpu_inference.tpu_info.glob.glob",
+       return_value=["/dev/accel0", "/dev/accel1"])
+def test_get_num_chips_from_accel(mock_glob):
+    """Test get_num_chips when /dev/accel* files exist."""
+    assert get_num_chips() == 2
+
+
+@patch("tpu_inference.tpu_info.glob.glob", return_value=[])
+@patch("tpu_inference.tpu_info.os.listdir", return_value=["0", "1", "2"])
+def test_get_num_chips_from_vfio(mock_listdir, mock_glob):
+    """Test get_num_chips when /dev/accel* files don't exist but /dev/vfio entries do."""
+    assert get_num_chips() == 3
+
+
+@patch("tpu_inference.tpu_info.glob.glob", return_value=[])
+@patch("tpu_inference.tpu_info.os.listdir", side_effect=FileNotFoundError)
+def test_get_num_chips_not_found(mock_listdir, mock_glob, caplog):
+    """Test get_num_chips when neither files nor directory are found."""
+    assert get_num_chips() == 0
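The mocks above pin down the observable behaviour of get_tpu_metadata from the caller's point of view: return the response body on an HTTP 200 and fall back to None when requests raises. A minimal sketch consistent with those tests is shown below; it is not the package's implementation, and the metadata URL and header are assumptions based on the standard GCE metadata-server convention.

    import requests

    # Hypothetical endpoint; the real lookup key comes from the caller (e.g. "accelerator-type").
    METADATA_URL = "http://metadata.google.internal/computeMetadata/v1/instance/attributes/{key}"

    def get_tpu_metadata_sketch(key: str):
        """Sketch mirroring the behaviour the tests above expect."""
        try:
            resp = requests.get(METADATA_URL.format(key=key),
                                headers={"Metadata-Flavor": "Google"})
            if resp.status_code == 200:
                return resp.text
        except requests.RequestException:
            # The tests expect None when the request fails.
            return None
        return None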
tests/test_utils.py
ADDED
@@ -0,0 +1,218 @@
+# SPDX-License-Identifier: Apache-2.0
+import os
+from unittest.mock import MagicMock, patch
+
+import jax.numpy as jnp
+import pytest
+
+# Import the functions to be tested
+from tpu_inference.utils import (GBYTES, enable_megacore,
+                                 get_jax_dtype_from_str_dtype, get_megacore,
+                                 get_padded_head_dim, hbm_usage_bytes,
+                                 hbm_usage_gb, quantize_kv)
+
+
+def test_enable_and_get_megacore():
+    """Tests the enable_megacore and get_megacore functions."""
+    assert not get_megacore()
+    enable_megacore()
+    assert get_megacore()
+
+
+@patch.dict(os.environ, {"TPU_MULTIHOST_BACKEND": "ray"})
+def test_hbm_usage_bytes_ray_backend():
+    """Tests hbm_usage_bytes when TPU_MULTIHOST_BACKEND is ray."""
+    mock_device1 = MagicMock()
+    mock_device1.memory_stats.return_value = {
+        "bytes_in_use": 100 * GBYTES,
+        "bytes_limit": 128 * GBYTES
+    }
+    mock_device2 = MagicMock()
+    mock_device2.memory_stats.side_effect = Exception("Memory stats failed")
+
+    devices = [mock_device1, mock_device2]
+    usage = hbm_usage_bytes(devices)
+
+    expected_usage = [(100 * GBYTES, 128 * GBYTES),
+                      (100 * GBYTES, 128 * GBYTES)]
+    assert usage == expected_usage
+
+
+@patch("vllm.envs.VLLM_TPU_USING_PATHWAYS", False)
+def test_hbm_usage_bytes_pathways_disabled():
+    """Tests hbm_usage_bytes when VLLM_TPU_USING_PATHWAYS is False."""
+    mock_device1 = MagicMock()
+    mock_device1.memory_stats.return_value = {
+        "bytes_in_use": 100 * GBYTES,
+        "bytes_limit": 128 * GBYTES
+    }
+    mock_device2 = MagicMock()
+    mock_device2.memory_stats.return_value = {
+        "bytes_in_use": 50 * GBYTES,
+        "bytes_limit": 128 * GBYTES
+    }
+
+    devices = [mock_device1, mock_device2]
+    usage = hbm_usage_bytes(devices)
+
+    expected_usage = [(100 * GBYTES, 128 * GBYTES),
+                      (50 * GBYTES, 128 * GBYTES)]
+    assert usage == expected_usage
+
+
+@patch("vllm.envs.VLLM_TPU_USING_PATHWAYS", True)
+@patch("jax.live_arrays")
+@patch("jax.devices")
+def test_hbm_usage_bytes_pathways_enabled(mock_devices, mock_live_arrays):
+    """Tests hbm_usage_bytes when VLLM_TPU_USING_PATHWAYS is True."""
+    # Mock TPU v5p devices
+    mock_jax_device = MagicMock()
+    mock_jax_device.device_kind = "TPU v5p"
+    mock_devices.return_value = [mock_jax_device]
+
+    # Create mock devices
+    mock_device1 = MagicMock()
+    mock_device2 = MagicMock()
+    devices = [mock_device1, mock_device2]
+
+    # Create mock arrays with sharding
+    mock_array1 = MagicMock()
+    mock_array1.dtype.itemsize = 4  # float32
+    mock_array1.size = 1000  # 1000 elements
+    mock_array1.sharding.device_set = {mock_device1, mock_device2
+                                       }  # Sharded across 2 devices
+
+    mock_array2 = MagicMock()
+    mock_array2.dtype.itemsize = 2  # float16
+    mock_array2.size = 500  # 500 elements
+    mock_array2.sharding.device_set = {mock_device1}  # Only on device1
+
+    mock_live_arrays.return_value = [mock_array1, mock_array2]
+
+    usage = hbm_usage_bytes(devices)
+
+    # Expected calculations:
+    # Array1: 4 bytes * 1000 elements / 2 devices = 2000 bytes per device
+    # Array2: 2 bytes * 500 elements / 1 device = 1000 bytes on device1 only
+    # Device1: 2000 + 1000 = 3000 bytes
+    # Device2: 2000 + 0 = 2000 bytes
+    # hbm_limit = 33550237184 (hardcoded in the function)
+    expected_usage = [(3000, 95 * GBYTES), (2000, 95 * GBYTES)]
+    assert usage == expected_usage
+
+
+@patch("vllm.envs.VLLM_TPU_USING_PATHWAYS", False)
+def test_hbm_usage_gb_pathways_disabled():
+    """Tests hbm_usage_gb when VLLM_TPU_USING_PATHWAYS is False."""
+    mock_device1 = MagicMock()
+    mock_device1.memory_stats.return_value = {
+        "bytes_in_use": 100 * GBYTES,
+        "bytes_limit": 128 * GBYTES
+    }
+    mock_device2 = MagicMock()
+    mock_device2.memory_stats.return_value = {
+        "bytes_in_use": 50.5 * GBYTES,
+        "bytes_limit": 128.0 * GBYTES
+    }
+
+    devices = [mock_device1, mock_device2]
+    usage = hbm_usage_gb(devices)
+
+    expected_usage = [(100.0, 128.0), (50.5, 128.0)]
+    assert usage == expected_usage
+
+
+@patch("vllm.envs.VLLM_TPU_USING_PATHWAYS", True)
+@patch("jax.live_arrays")
+@patch("jax.devices")
+def test_hbm_usage_bytes_pathways_no_arrays(mock_devices, mock_live_arrays):
+    """Tests hbm_usage_bytes when VLLM_TPU_USING_PATHWAYS is True but no live arrays."""
+    # Mock TPU v5e devices
+    mock_jax_device = MagicMock()
+    mock_jax_device.device_kind = "TPU v6e"
+    mock_devices.return_value = [mock_jax_device]
+
+    mock_device1 = MagicMock()
+    mock_device2 = MagicMock()
+    devices = [mock_device1, mock_device2]
+
+    # No live arrays
+    mock_live_arrays.return_value = []
+
+    usage = hbm_usage_bytes(devices)
+
+    # No arrays means no memory usage
+    expected_usage = [(0, 32 * GBYTES), (0, 32 * GBYTES)]
+    assert usage == expected_usage
+
+
+@pytest.mark.parametrize(
+    "head_dim, expected_padded_head_dim",
+    [
+        (1, 128),
+        (64, 128),
+        (127, 128),
+        (128, 128),
+        (129, 256),
+        (255, 256),
+        (256, 256),
+        (0, 0),  # Although head_dim is usually positive, testing boundary
+    ],
+)
+def test_get_padded_head_dim(head_dim, expected_padded_head_dim):
+    """Tests the get_padded_head_dim function."""
+    assert get_padded_head_dim(head_dim) == expected_padded_head_dim
+
+
+def test_quantize_kv_float8_e4m3fn():
+    """Tests the quantize_kv function with float8_e4m3fn dtype."""
+    key = jnp.array([-1.0, 0.5, 1.0, 1.5])
+    value = jnp.array([2.0, 0.0, -2.0, -3.0])
+    kv_cache_quantized_dtype = jnp.float8_e4m3fn
+    k_scale = 0.1
+    v_scale = 0.2
+
+    quantized_key, quantized_value = quantize_kv(key, value,
+                                                 kv_cache_quantized_dtype,
+                                                 k_scale, v_scale)
+
+    # Expected key: key / k_scale -> clip -> astype
+    # [-10., 5., 10., 15.] are within float8_e4m3fn range
+    expected_key = jnp.array([-10.0, 5.0, 10.0, 15.0], dtype=jnp.float8_e4m3fn)
+
+    # Expected value: value / v_scale -> clip -> astype
+    # [10., 0., -10., -15.] are within float8_e4m3fn range
+    expected_value = jnp.array([10.0, 0.0, -10.0, -15.0],
+                               dtype=jnp.float8_e4m3fn)
+
+    assert jnp.array_equal(quantized_key, expected_key)
+    assert jnp.array_equal(quantized_value, expected_value)
+
+    # Test clipping
+    dtype_info = jnp.finfo(kv_cache_quantized_dtype)
+    minval, maxval = float(dtype_info.min), float(dtype_info.max)
+
+    # Values that will be outside the range after scaling
+    key_clip = jnp.array([minval * k_scale * 2, maxval * k_scale * 2])
+    value_clip = jnp.array([maxval * v_scale * 2, minval * v_scale * 2])
+    quantized_key_clip, quantized_value_clip = quantize_kv(
+        key_clip, value_clip, kv_cache_quantized_dtype, k_scale, v_scale)
+
+    # Values should be clipped to the min/max of the float8 dtype
+    expected_key_clip = jnp.array([minval, maxval], dtype=jnp.float8_e4m3fn)
+    expected_value_clip = jnp.array([maxval, minval], dtype=jnp.float8_e4m3fn)
+
+    assert jnp.array_equal(quantized_key_clip, expected_key_clip)
+    assert jnp.array_equal(quantized_value_clip, expected_value_clip)
+
+
+def test_get_jax_dtype_from_str_dtype():
+    """
+    Test the get_jax_dtype_from_str_dtype function
+    """
+    assert get_jax_dtype_from_str_dtype("int8") == jnp.int8
+    assert get_jax_dtype_from_str_dtype("bfloat16") == jnp.bfloat16
+    assert get_jax_dtype_from_str_dtype("fp8") == jnp.float8_e4m3fn
+    assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.float8_e4m3
+    assert get_jax_dtype_from_str_dtype("fp8_e5m2") == jnp.float8_e5m2
+    assert get_jax_dtype_from_str_dtype("auto") is None
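The parametrized cases for get_padded_head_dim above imply that the head dimension is rounded up to the next multiple of 128 (with 0 staying 0). A one-line sketch that reproduces exactly that table, shown only as a reading aid and not as the package's implementation:

    def padded_head_dim(head_dim: int) -> int:
        # Round up to the next multiple of 128: 0 -> 0, 1 -> 128, 129 -> 256.
        return ((head_dim + 127) // 128) * 128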
tests/tpu_backend_test.py
ADDED
@@ -0,0 +1,59 @@
+import unittest
+from unittest.mock import Mock, patch
+
+from tpu_inference.backend import TPUBackend
+
+
+class TPUBackendTest(unittest.TestCase):
+
+    @patch('tpu_inference.backend.TPUWorker')
+    def test_tpu_backend_initialization(self, mock_tpu_worker_class):
+        """Test that TPUBackend initializes the worker correctly."""
+        mock_host_interface = Mock()
+        mock_worker_kwargs = {'worker_arg': 'test_value'}
+
+        backend = TPUBackend(host_interface=mock_host_interface,
+                             **mock_worker_kwargs)
+
+        # Assert that the TPUWorker was instantiated with the correct arguments
+        mock_tpu_worker_class.assert_called_once_with(
+            host_interface=mock_host_interface, **mock_worker_kwargs)
+
+        # Assert that the worker attribute is an instance of the mock class
+        self.assertEqual(backend.worker, mock_tpu_worker_class.return_value)
+
+    @patch('tpu_inference.backend.VllmSchedulerOutputAdapter')
+    @patch('tpu_inference.backend.TPUWorker')
+    def test_launch_tpu_batch(self, mock_tpu_worker_class, mock_adapter_class):
+        """Test that launch_tpu_batch delegates to the worker correctly."""
+        mock_worker_instance = mock_tpu_worker_class.return_value
+
+        backend = TPUBackend()
+        mock_batch = Mock()
+
+        backend.launch_tpu_batch(mock_batch)
+
+        # Assert that the adapter was created with the correct input
+        mock_adapter_class.assert_called_once_with(mock_batch)
+
+        # Assert that the worker's execute_model method was called with the mock adapter's return value
+        mock_worker_instance.execute_model.assert_called_once_with(
+            mock_adapter_class.return_value)
+
+    @patch('tpu_inference.backend.VllmLoRARequestAdapter')
+    @patch('tpu_inference.backend.TPUWorker')
+    def test_add_lora(self, mock_tpu_worker_class, mock_adapter_class):
+        """Test that add_lora delegates to the worker correctly."""
+        mock_worker_instance = mock_tpu_worker_class.return_value
+
+        backend = TPUBackend()
+        mock_lora_request = Mock()
+
+        backend.add_lora(mock_lora_request)
+
+        # Assert that the adapter was created with the correct input
+        mock_adapter_class.assert_called_once_with(mock_lora_request)
+
+        # Assert that the worker's add_lora method was called with the mock adapter's return value
+        mock_worker_instance.add_lora.assert_called_once_with(
+            mock_adapter_class.return_value)
tpu_inference/__init__.py
ADDED
@@ -0,0 +1,30 @@
+import os
+
+from tpu_inference import tpu_info as ti
+from tpu_inference.logger import init_logger
+
+logger = init_logger(__name__)
+
+if "proxy" in os.environ.get('JAX_PLATFORMS', '').lower():
+    logger.info("Running vLLM on TPU via Pathways proxy.")
+    # Must run pathwaysutils.initialize() before any JAX operations
+    try:
+        import pathwaysutils
+        pathwaysutils.initialize()
+        logger.info("Module pathwaysutils is imported.")
+    except Exception as e:
+        logger.error(
+            f"Error occurred while importing pathwaysutils or logging TPU info: {e}"
+        )
+else:
+    # Either running on TPU or CPU
+    try:
+        logger.info(f"TPU info: node_name={ti.get_node_name()} | "
+                    f"tpu_type={ti.get_tpu_type()} | "
+                    f"worker_id={ti.get_node_worker_id()} | "
+                    f"num_chips={ti.get_num_chips()} | "
+                    f"num_cores_per_chip={ti.get_num_cores_per_chip()}")
+    except Exception as e:
+        logger.error(
+            f"Error occurred while logging TPU info: {e}. Are you running on CPU?"
+        )

File without changes
tpu_inference/adapters/vllm_adapters.py
ADDED
@@ -0,0 +1,42 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from vllm.lora.request import LoRARequest
+from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
+from vllm.v1.outputs import ModelRunnerOutput
+
+from tpu_inference.di.abstracts import (AbstractKVCacheConfig,
+                                        AbstractKVCacheSpec,
+                                        AbstractLoRARequest,
+                                        AbstractModelRunnerOutput,
+                                        AbstractSchedulerOutput)
+
+
+class VllmModelRunnerOutputAdapter(AbstractModelRunnerOutput):
+
+    def __init__(self, vllm_output: ModelRunnerOutput):
+        self.vllm_output = vllm_output
+
+
+class VllmSchedulerOutputAdapter(AbstractSchedulerOutput):
+
+    def __init__(self, vllm_scheduler_output: SchedulerOutput):
+        self.vllm_scheduler_output = vllm_scheduler_output
+
+
+class VllmLoRARequestAdapter(AbstractLoRARequest):
+
+    def __init__(self, vllm_lora_request: LoRARequest):
+        self.vllm_lora_request = vllm_lora_request
+
+
+class VllmKVCacheConfigAdapter(AbstractKVCacheConfig):
+
+    def __init__(self, vllm_kv_cache_config: KVCacheConfig):
+        self.vllm_kv_cache_config = vllm_kv_cache_config
+
+
+class VllmKVCacheSpecAdapter(AbstractKVCacheSpec):
+
+    def __init__(self, vllm_kv_cache_spec: KVCacheSpec):
+        self.vllm_kv_cache_spec = vllm_kv_cache_spec
tpu_inference/adapters/vllm_config_adapters.py
ADDED
@@ -0,0 +1,134 @@
+"""
+Adapters for wrapping concrete vLLM config objects in tpu_inference interfaces.
+"""
+from typing import Any, Optional
+
+import torch
+
+from tpu_inference.interfaces.config_parts import (ICacheConfig,
+                                                   ICompilationConfig,
+                                                   IModelConfig,
+                                                   IParallelConfig,
+                                                   ISchedulerConfig)
+
+
+class VllmCacheConfigAdapter(ICacheConfig):
+
+    def __init__(self, vllm_cache_config: Any):
+        self._vllm_cache_config = vllm_cache_config
+
+    @property
+    def block_size(self) -> Optional[int]:
+        return self._vllm_cache_config.block_size
+
+    @block_size.setter
+    def block_size(self, value: Optional[int]) -> None:
+        self._vllm_cache_config.block_size = value
+
+
+class VllmCompilationConfigAdapter(ICompilationConfig):
+
+    def __init__(self, vllm_compilation_config: Any):
+        self._vllm_compilation_config = vllm_compilation_config
+
+    @property
+    def level(self) -> Any:
+        return self._vllm_compilation_config.level
+
+    @level.setter
+    def level(self, value: Any) -> None:
+        self._vllm_compilation_config.level = value
+
+    @property
+    def backend(self) -> str:
+        return self._vllm_compilation_config.backend
+
+    @backend.setter
+    def backend(self, value: str) -> None:
+        self._vllm_compilation_config.backend = value
+
+
+class VllmModelConfigAdapter(IModelConfig):
+
+    def __init__(self, vllm_model_config: Any):
+        self._vllm_model_config = vllm_model_config
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self._vllm_model_config.dtype
+
+    @dtype.setter
+    def dtype(self, value: torch.dtype) -> None:
+        self._vllm_model_config.dtype = value
+
+    @property
+    def use_mla(self) -> bool:
+        return self._vllm_model_config.use_mla
+
+
+class VllmParallelConfigAdapter(IParallelConfig):
+
+    def __init__(self, vllm_parallel_config: Any):
+        self._vllm_parallel_config = vllm_parallel_config
+
+    @property
+    def worker_cls(self) -> str:
+        return self._vllm_parallel_config.worker_cls
+
+    @worker_cls.setter
+    def worker_cls(self, value: str) -> None:
+        self._vllm_parallel_config.worker_cls = value
+
+
+class VllmSchedulerConfigAdapter(ISchedulerConfig):
+
+    def __init__(self, vllm_scheduler_config: Any):
+        self._vllm_scheduler_config = vllm_scheduler_config
+
+    @property
+    def max_num_seqs(self) -> int:
+        return self._vllm_scheduler_config.max_num_seqs
+
+    @property
+    def is_multi_step(self) -> bool:
+        return self._vllm_scheduler_config.is_multi_step
+
+    @property
+    def is_multimodal_model(self) -> bool:
+        return self._vllm_scheduler_config.is_multimodal_model
+
+    @property
+    def disable_chunked_mm_input(self) -> bool:
+        return self._vllm_scheduler_config.disable_chunked_mm_input
+
+    @disable_chunked_mm_input.setter
+    def disable_chunked_mm_input(self, value: bool) -> None:
+        self._vllm_scheduler_config.disable_chunked_mm_input = value
+
+    @property
+    def enable_chunked_prefill(self) -> bool:
+        return self._vllm_scheduler_config.enable_chunked_prefill
+
+    @enable_chunked_prefill.setter
+    def enable_chunked_prefill(self, value: bool) -> None:
+        self._vllm_scheduler_config.enable_chunked_prefill = value
+
+    @property
+    def chunked_prefill_enabled(self) -> bool:
+        return self._vllm_scheduler_config.chunked_prefill_enabled
+
+    @chunked_prefill_enabled.setter
+    def chunked_prefill_enabled(self, value: bool) -> None:
+        self._vllm_scheduler_config.chunked_prefill_enabled = value
+
+    @property
+    def max_model_len(self) -> int:
+        return self._vllm_scheduler_config.max_model_len
+
+    @property
+    def max_num_batched_tokens(self) -> int:
+        return self._vllm_scheduler_config.max_num_batched_tokens
+
+    @max_num_batched_tokens.setter
+    def max_num_batched_tokens(self, value: int) -> None:
+        self._vllm_scheduler_config.max_num_batched_tokens = value
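Each adapter above simply stores the concrete vLLM config object and forwards attribute access through properties, with setters where the TPU platform needs to write values back. A short usage sketch of that pattern follows; the SimpleNamespace is a hypothetical stand-in for a real vLLM SchedulerConfig, and the module path is taken from the file listing above, assuming the package and its dependencies are importable.

    from types import SimpleNamespace

    from tpu_inference.adapters.vllm_config_adapters import VllmSchedulerConfigAdapter

    # Hypothetical stand-in carrying the attributes the adapter reads.
    raw_config = SimpleNamespace(max_num_seqs=256, max_num_batched_tokens=8192,
                                 max_model_len=4096, is_multi_step=False,
                                 is_multimodal_model=False,
                                 disable_chunked_mm_input=False,
                                 enable_chunked_prefill=True,
                                 chunked_prefill_enabled=True)

    adapter = VllmSchedulerConfigAdapter(raw_config)
    assert adapter.max_num_seqs == 256
    adapter.max_num_batched_tokens = 4096  # setter writes through to the wrapped object
    assert raw_config.max_num_batched_tokens == 4096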
tpu_inference/backend.py
ADDED
@@ -0,0 +1,69 @@
+"""
+Copyright 2025 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+from typing import Optional
+
+from tpu_inference.adapters.vllm_adapters import (VllmLoRARequestAdapter,
+                                                  VllmSchedulerOutputAdapter)
+from tpu_inference.di.interfaces import BackendInterface, HostInterface
+from tpu_inference.worker.base import AbstractTpuWorker
+from tpu_inference.worker.tpu_worker_jax import TPUWorker
+
+
+class TPUBackend(BackendInterface):
+    """
+    The main entry point for the host system to interact with the TPU backend.
+
+    This class implements the BackendInterface. It is responsible for creating
+    and managing the concrete TPU worker instance and delegating calls to it.
+    """
+
+    def __init__(self,
+                 host_interface: Optional[HostInterface] = None,
+                 **worker_kwargs):
+        """
+        Initializes the TPUBackend.
+
+        Args:
+            host_interface: An optional object that implements the HostInterface,
+                providing a way for the backend to communicate with the host.
+            **worker_kwargs: Additional keyword arguments to be passed to the
+                worker's constructor.
+        """
+        self.worker: AbstractTpuWorker = TPUWorker(
+            host_interface=host_interface, **worker_kwargs)
+
+    def launch_tpu_batch(self, batch_to_launch):
+        """
+        Launches a batch of requests on the TPU worker and returns the result.
+
+        Args:
+            batch_to_launch: The batch of requests to be processed.
+
+        Returns:
+            The result of the model execution.
+        """
+        adapted_batch = VllmSchedulerOutputAdapter(batch_to_launch)
+        return self.worker.execute_model(adapted_batch)
+
+    def add_lora(self, lora_request):
+        """
+        Adds a LoRA adapter to the worker.
+
+        Args:
+            lora_request: The LoRA request to be processed.
+        """
+        adapted_lora_request = VllmLoRARequestAdapter(lora_request)
+        return self.worker.add_lora(adapted_lora_request)

File without changes