tpu-inference 0.11.1.dev202511150811__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tests/test_tpu_info.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from unittest.mock import MagicMock, patch
|
|
3
|
+
|
|
4
|
+
import pytest
|
|
5
|
+
import requests
|
|
6
|
+
|
|
7
|
+
from tpu_inference.tpu_info import (get_node_name, get_node_worker_id,
|
|
8
|
+
get_num_chips, get_num_cores_per_chip,
|
|
9
|
+
get_tpu_metadata, get_tpu_type)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Mock requests.get for get_tpu_metadata tests
|
|
13
|
+
@patch("tpu_inference.tpu_info.requests.get")
|
|
14
|
+
def test_get_tpu_metadata_success(mock_get):
|
|
15
|
+
"""Test get_tpu_metadata when the request is successful."""
|
|
16
|
+
mock_response = MagicMock()
|
|
17
|
+
mock_response.status_code = 200
|
|
18
|
+
mock_response.text = "test_metadata_value"
|
|
19
|
+
mock_get.return_value = mock_response
|
|
20
|
+
assert get_tpu_metadata(key="test-key") == "test_metadata_value"
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@patch("tpu_inference.tpu_info.requests.get")
|
|
24
|
+
def test_get_tpu_metadata_request_error(mock_get):
|
|
25
|
+
"""Test get_tpu_metadata when a RequestException is raised."""
|
|
26
|
+
mock_get.side_effect = requests.RequestException("Test RequestException")
|
|
27
|
+
assert get_tpu_metadata(key="test-key") is None
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Test get_tpu_type
|
|
31
|
+
@patch("tpu_inference.tpu_info.get_tpu_metadata")
|
|
32
|
+
@patch.dict(os.environ, {"TPU_ACCELERATOR_TYPE": "env_tpu_type"})
|
|
33
|
+
def test_get_tpu_type_from_env(mock_get_tpu_metadata):
|
|
34
|
+
"""Test get_tpu_type when TPU_ACCELERATOR_TYPE is set in environment."""
|
|
35
|
+
# The function should return the env var value and not call get_tpu_metadata
|
|
36
|
+
assert get_tpu_type() == "env_tpu_type"
|
|
37
|
+
mock_get_tpu_metadata.assert_not_called()
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@patch.dict(os.environ, {}, clear=True)
|
|
41
|
+
@patch("tpu_inference.tpu_info.get_tpu_metadata",
|
|
42
|
+
return_value="metadata_tpu_type")
|
|
43
|
+
def test_get_tpu_type_from_metadata(mock_get_tpu_metadata):
|
|
44
|
+
"""Test get_tpu_type when environment variable is not set."""
|
|
45
|
+
assert get_tpu_type() == "metadata_tpu_type"
|
|
46
|
+
mock_get_tpu_metadata.assert_called_once_with(key="accelerator-type")
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# Test get_node_name
|
|
50
|
+
@patch("tpu_inference.tpu_info.get_tpu_metadata")
|
|
51
|
+
@patch.dict(os.environ, {"TPU_NAME": "env_tpu_name"})
|
|
52
|
+
def test_get_node_name_from_env(mock_get_tpu_metadata):
|
|
53
|
+
"""Test get_node_name when TPU_NAME is set in environment."""
|
|
54
|
+
assert get_node_name() == "env_tpu_name"
|
|
55
|
+
mock_get_tpu_metadata.assert_not_called()
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@patch.dict(os.environ, {}, clear=True)
|
|
59
|
+
@patch("tpu_inference.tpu_info.get_tpu_metadata",
|
|
60
|
+
return_value="metadata_tpu_name")
|
|
61
|
+
def test_get_node_name_from_metadata(mock_get_tpu_metadata):
|
|
62
|
+
"""Test get_node_name when environment variable is not set."""
|
|
63
|
+
assert get_node_name() == "metadata_tpu_name"
|
|
64
|
+
mock_get_tpu_metadata.assert_called_once_with(key="instance-id")
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
# Test get_node_worker_id
|
|
68
|
+
@patch("tpu_inference.tpu_info.get_tpu_metadata")
|
|
69
|
+
@patch.dict(os.environ, {"TPU_WORKER_ID": "5"})
|
|
70
|
+
def test_get_node_worker_id_from_env(mock_get_tpu_metadata):
|
|
71
|
+
"""Test get_node_worker_id when TPU_WORKER_ID is set in environment."""
|
|
72
|
+
assert get_node_worker_id() == 5
|
|
73
|
+
mock_get_tpu_metadata.assert_not_called()
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@patch.dict(os.environ, {}, clear=True)
|
|
77
|
+
@patch("tpu_inference.tpu_info.get_tpu_metadata", return_value="10")
|
|
78
|
+
def test_get_node_worker_id_from_metadata(mock_get_tpu_metadata):
|
|
79
|
+
"""Test get_node_worker_id when environment variable is not set."""
|
|
80
|
+
assert get_node_worker_id() == 10
|
|
81
|
+
mock_get_tpu_metadata.assert_called_once_with(key="agent-worker-number")
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# Test get_num_cores_per_chip
|
|
85
|
+
@pytest.mark.parametrize(
|
|
86
|
+
"tpu_type, expected",
|
|
87
|
+
[
|
|
88
|
+
("v5litepod-4", 1),
|
|
89
|
+
("v6e-8", 1),
|
|
90
|
+
("v4-8", 2),
|
|
91
|
+
("v5p-16", 2),
|
|
92
|
+
("unknown-type", 2) # Default case
|
|
93
|
+
])
|
|
94
|
+
@patch("tpu_inference.tpu_info.get_tpu_type")
|
|
95
|
+
def test_get_num_cores_per_chip(mock_get_tpu_type, tpu_type, expected):
|
|
96
|
+
"""Test get_num_cores_per_chip with different TPU types."""
|
|
97
|
+
mock_get_tpu_type.return_value = tpu_type
|
|
98
|
+
assert get_num_cores_per_chip() == expected
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
# Test get_num_chips
|
|
102
|
+
@patch("tpu_inference.tpu_info.glob.glob",
|
|
103
|
+
return_value=["/dev/accel0", "/dev/accel1"])
|
|
104
|
+
def test_get_num_chips_from_accel(mock_glob):
|
|
105
|
+
"""Test get_num_chips when /dev/accel* files exist."""
|
|
106
|
+
assert get_num_chips() == 2
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@patch("tpu_inference.tpu_info.glob.glob", return_value=[])
|
|
110
|
+
@patch("tpu_inference.tpu_info.os.listdir", return_value=["0", "1", "2"])
|
|
111
|
+
def test_get_num_chips_from_vfio(mock_listdir, mock_glob):
|
|
112
|
+
"""Test get_num_chips when /dev/accel* files don't exist but /dev/vfio entries do."""
|
|
113
|
+
assert get_num_chips() == 3
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
@patch("tpu_inference.tpu_info.glob.glob", return_value=[])
|
|
117
|
+
@patch("tpu_inference.tpu_info.os.listdir", side_effect=FileNotFoundError)
|
|
118
|
+
def test_get_num_chips_not_found(mock_listdir, mock_glob, caplog):
|
|
119
|
+
"""Test get_num_chips when neither files nor directory are found."""
|
|
120
|
+
assert get_num_chips() == 0
|
tests/test_utils.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
import os
|
|
3
|
+
from unittest.mock import MagicMock, patch
|
|
4
|
+
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
import pytest
|
|
7
|
+
|
|
8
|
+
# Import the functions to be tested
|
|
9
|
+
from tpu_inference.utils import (GBYTES, enable_megacore,
|
|
10
|
+
get_jax_dtype_from_str_dtype, get_megacore,
|
|
11
|
+
get_padded_head_dim, hbm_usage_bytes,
|
|
12
|
+
hbm_usage_gb, quantize_kv)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def test_enable_and_get_megacore():
|
|
16
|
+
"""Tests the enable_megacore and get_megacore functions."""
|
|
17
|
+
assert not get_megacore()
|
|
18
|
+
enable_megacore()
|
|
19
|
+
assert get_megacore()
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@patch.dict(os.environ, {"TPU_MULTIHOST_BACKEND": "ray"})
|
|
23
|
+
def test_hbm_usage_bytes_ray_backend():
|
|
24
|
+
"""Tests hbm_usage_bytes when TPU_MULTIHOST_BACKEND is ray."""
|
|
25
|
+
mock_device1 = MagicMock()
|
|
26
|
+
mock_device1.memory_stats.return_value = {
|
|
27
|
+
"bytes_in_use": 100 * GBYTES,
|
|
28
|
+
"bytes_limit": 128 * GBYTES
|
|
29
|
+
}
|
|
30
|
+
mock_device2 = MagicMock()
|
|
31
|
+
mock_device2.memory_stats.side_effect = Exception("Memory stats failed")
|
|
32
|
+
|
|
33
|
+
devices = [mock_device1, mock_device2]
|
|
34
|
+
usage = hbm_usage_bytes(devices)
|
|
35
|
+
|
|
36
|
+
expected_usage = [(100 * GBYTES, 128 * GBYTES),
|
|
37
|
+
(100 * GBYTES, 128 * GBYTES)]
|
|
38
|
+
assert usage == expected_usage
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@patch("vllm.envs.VLLM_TPU_USING_PATHWAYS", False)
|
|
42
|
+
def test_hbm_usage_bytes_pathways_disabled():
|
|
43
|
+
"""Tests hbm_usage_bytes when VLLM_TPU_USING_PATHWAYS is False."""
|
|
44
|
+
mock_device1 = MagicMock()
|
|
45
|
+
mock_device1.memory_stats.return_value = {
|
|
46
|
+
"bytes_in_use": 100 * GBYTES,
|
|
47
|
+
"bytes_limit": 128 * GBYTES
|
|
48
|
+
}
|
|
49
|
+
mock_device2 = MagicMock()
|
|
50
|
+
mock_device2.memory_stats.return_value = {
|
|
51
|
+
"bytes_in_use": 50 * GBYTES,
|
|
52
|
+
"bytes_limit": 128 * GBYTES
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
devices = [mock_device1, mock_device2]
|
|
56
|
+
usage = hbm_usage_bytes(devices)
|
|
57
|
+
|
|
58
|
+
expected_usage = [(100 * GBYTES, 128 * GBYTES),
|
|
59
|
+
(50 * GBYTES, 128 * GBYTES)]
|
|
60
|
+
assert usage == expected_usage
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@patch("vllm.envs.VLLM_TPU_USING_PATHWAYS", True)
|
|
64
|
+
@patch("jax.live_arrays")
|
|
65
|
+
@patch("jax.devices")
|
|
66
|
+
def test_hbm_usage_bytes_pathways_enabled(mock_devices, mock_live_arrays):
|
|
67
|
+
"""Tests hbm_usage_bytes when VLLM_TPU_USING_PATHWAYS is True."""
|
|
68
|
+
# Mock TPU v5p devices
|
|
69
|
+
mock_jax_device = MagicMock()
|
|
70
|
+
mock_jax_device.device_kind = "TPU v5p"
|
|
71
|
+
mock_devices.return_value = [mock_jax_device]
|
|
72
|
+
|
|
73
|
+
# Create mock devices
|
|
74
|
+
mock_device1 = MagicMock()
|
|
75
|
+
mock_device2 = MagicMock()
|
|
76
|
+
devices = [mock_device1, mock_device2]
|
|
77
|
+
|
|
78
|
+
# Create mock addressable shards with data property
|
|
79
|
+
mock_data1_dev1 = MagicMock()
|
|
80
|
+
mock_data1_dev1.device = mock_device1
|
|
81
|
+
mock_data1_dev1.nbytes = 2000 # 2000 bytes on device1
|
|
82
|
+
|
|
83
|
+
mock_data1_dev2 = MagicMock()
|
|
84
|
+
mock_data1_dev2.device = mock_device2
|
|
85
|
+
mock_data1_dev2.nbytes = 2000 # 2000 bytes on device2
|
|
86
|
+
|
|
87
|
+
mock_data2_dev1 = MagicMock()
|
|
88
|
+
mock_data2_dev1.device = mock_device1
|
|
89
|
+
mock_data2_dev1.nbytes = 1000 # 1000 bytes on device1
|
|
90
|
+
|
|
91
|
+
mock_shard1_dev1 = MagicMock()
|
|
92
|
+
mock_shard1_dev1.data = mock_data1_dev1
|
|
93
|
+
|
|
94
|
+
mock_shard1_dev2 = MagicMock()
|
|
95
|
+
mock_shard1_dev2.data = mock_data1_dev2
|
|
96
|
+
|
|
97
|
+
mock_shard2_dev1 = MagicMock()
|
|
98
|
+
mock_shard2_dev1.data = mock_data2_dev1
|
|
99
|
+
|
|
100
|
+
# Create mock arrays with addressable_shards
|
|
101
|
+
mock_array1 = MagicMock()
|
|
102
|
+
mock_array1.addressable_shards = [mock_shard1_dev1, mock_shard1_dev2]
|
|
103
|
+
|
|
104
|
+
mock_array2 = MagicMock()
|
|
105
|
+
mock_array2.addressable_shards = [mock_shard2_dev1]
|
|
106
|
+
|
|
107
|
+
mock_live_arrays.return_value = [mock_array1, mock_array2]
|
|
108
|
+
|
|
109
|
+
usage = hbm_usage_bytes(devices)
|
|
110
|
+
|
|
111
|
+
# Expected calculations:
|
|
112
|
+
# Array1: 2000 bytes on device1, 2000 bytes on device2
|
|
113
|
+
# Array2: 1000 bytes on device1
|
|
114
|
+
# Device1 total: 2000 + 1000 = 3000 bytes
|
|
115
|
+
# Device2 total: 2000 + 0 = 2000 bytes
|
|
116
|
+
# hbm_limit = 95 * GBYTES for TPU v5p
|
|
117
|
+
expected_usage = [(3000, 95 * GBYTES), (2000, 95 * GBYTES)]
|
|
118
|
+
assert usage == expected_usage
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
@patch("vllm.envs.VLLM_TPU_USING_PATHWAYS", False)
|
|
122
|
+
def test_hbm_usage_gb_pathways_disabled():
|
|
123
|
+
"""Tests hbm_usage_gb when VLLM_TPU_USING_PATHWAYS is False."""
|
|
124
|
+
mock_device1 = MagicMock()
|
|
125
|
+
mock_device1.memory_stats.return_value = {
|
|
126
|
+
"bytes_in_use": 100 * GBYTES,
|
|
127
|
+
"bytes_limit": 128 * GBYTES
|
|
128
|
+
}
|
|
129
|
+
mock_device2 = MagicMock()
|
|
130
|
+
mock_device2.memory_stats.return_value = {
|
|
131
|
+
"bytes_in_use": 50.5 * GBYTES,
|
|
132
|
+
"bytes_limit": 128.0 * GBYTES
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
devices = [mock_device1, mock_device2]
|
|
136
|
+
usage = hbm_usage_gb(devices)
|
|
137
|
+
|
|
138
|
+
expected_usage = [(100.0, 128.0), (50.5, 128.0)]
|
|
139
|
+
assert usage == expected_usage
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
@patch("vllm.envs.VLLM_TPU_USING_PATHWAYS", True)
|
|
143
|
+
@patch("jax.live_arrays")
|
|
144
|
+
@patch("jax.devices")
|
|
145
|
+
def test_hbm_usage_bytes_pathways_no_arrays(mock_devices, mock_live_arrays):
|
|
146
|
+
"""Tests hbm_usage_bytes when VLLM_TPU_USING_PATHWAYS is True but no live arrays."""
|
|
147
|
+
# Mock TPU v6e devices
|
|
148
|
+
mock_jax_device = MagicMock()
|
|
149
|
+
mock_jax_device.device_kind = "TPU v6e"
|
|
150
|
+
mock_devices.return_value = [mock_jax_device]
|
|
151
|
+
|
|
152
|
+
mock_device1 = MagicMock()
|
|
153
|
+
mock_device2 = MagicMock()
|
|
154
|
+
devices = [mock_device1, mock_device2]
|
|
155
|
+
|
|
156
|
+
# No live arrays
|
|
157
|
+
mock_live_arrays.return_value = []
|
|
158
|
+
|
|
159
|
+
usage = hbm_usage_bytes(devices)
|
|
160
|
+
|
|
161
|
+
# No arrays means no memory usage, defaultdict returns 0 for missing keys
|
|
162
|
+
# HBM limit for TPU v6e is 32 GB
|
|
163
|
+
expected_usage = [(0, 32 * GBYTES), (0, 32 * GBYTES)]
|
|
164
|
+
assert usage == expected_usage
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
@pytest.mark.parametrize(
|
|
168
|
+
"head_dim, expected_padded_head_dim",
|
|
169
|
+
[
|
|
170
|
+
(1, 128),
|
|
171
|
+
(64, 64),
|
|
172
|
+
(127, 128),
|
|
173
|
+
(128, 128),
|
|
174
|
+
(129, 256),
|
|
175
|
+
(255, 256),
|
|
176
|
+
(256, 256),
|
|
177
|
+
(0, 0), # Although head_dim is usually positive, testing boundary
|
|
178
|
+
],
|
|
179
|
+
)
|
|
180
|
+
def test_get_padded_head_dim(head_dim, expected_padded_head_dim):
|
|
181
|
+
"""Tests the get_padded_head_dim function."""
|
|
182
|
+
assert get_padded_head_dim(head_dim) == expected_padded_head_dim
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def test_quantize_kv_float8_e4m3fn():
|
|
186
|
+
"""Tests the quantize_kv function with float8_e4m3fn dtype."""
|
|
187
|
+
key = jnp.array([-1.0, 0.5, 1.0, 1.5])
|
|
188
|
+
value = jnp.array([2.0, 0.0, -2.0, -3.0])
|
|
189
|
+
kv_cache_quantized_dtype = jnp.float8_e4m3fn
|
|
190
|
+
k_scale = 0.1
|
|
191
|
+
v_scale = 0.2
|
|
192
|
+
|
|
193
|
+
quantized_key, quantized_value = quantize_kv(key, value,
|
|
194
|
+
kv_cache_quantized_dtype,
|
|
195
|
+
k_scale, v_scale)
|
|
196
|
+
|
|
197
|
+
# Expected key: key / k_scale -> clip -> astype
|
|
198
|
+
# [-10., 5., 10., 15.] are within float8_e4m3fn range
|
|
199
|
+
expected_key = jnp.array([-10.0, 5.0, 10.0, 15.0], dtype=jnp.float8_e4m3fn)
|
|
200
|
+
|
|
201
|
+
# Expected value: value / v_scale -> clip -> astype
|
|
202
|
+
# [10., 0., -10., -15.] are within float8_e4m3fn range
|
|
203
|
+
expected_value = jnp.array([10.0, 0.0, -10.0, -15.0],
|
|
204
|
+
dtype=jnp.float8_e4m3fn)
|
|
205
|
+
|
|
206
|
+
assert jnp.array_equal(quantized_key, expected_key)
|
|
207
|
+
assert jnp.array_equal(quantized_value, expected_value)
|
|
208
|
+
|
|
209
|
+
# Test clipping
|
|
210
|
+
dtype_info = jnp.finfo(kv_cache_quantized_dtype)
|
|
211
|
+
minval, maxval = float(dtype_info.min), float(dtype_info.max)
|
|
212
|
+
|
|
213
|
+
# Values that will be outside the range after scaling
|
|
214
|
+
key_clip = jnp.array([minval * k_scale * 2, maxval * k_scale * 2])
|
|
215
|
+
value_clip = jnp.array([maxval * v_scale * 2, minval * v_scale * 2])
|
|
216
|
+
quantized_key_clip, quantized_value_clip = quantize_kv(
|
|
217
|
+
key_clip, value_clip, kv_cache_quantized_dtype, k_scale, v_scale)
|
|
218
|
+
|
|
219
|
+
# Values should be clipped to the min/max of the float8 dtype
|
|
220
|
+
expected_key_clip = jnp.array([minval, maxval], dtype=jnp.float8_e4m3fn)
|
|
221
|
+
expected_value_clip = jnp.array([maxval, minval], dtype=jnp.float8_e4m3fn)
|
|
222
|
+
|
|
223
|
+
assert jnp.array_equal(quantized_key_clip, expected_key_clip)
|
|
224
|
+
assert jnp.array_equal(quantized_value_clip, expected_value_clip)
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def test_get_jax_dtype_from_str_dtype():
|
|
228
|
+
"""
|
|
229
|
+
Test the get_jax_dtype_from_str_dtype function
|
|
230
|
+
"""
|
|
231
|
+
assert get_jax_dtype_from_str_dtype("int8") == jnp.int8
|
|
232
|
+
assert get_jax_dtype_from_str_dtype("bfloat16") == jnp.bfloat16
|
|
233
|
+
assert get_jax_dtype_from_str_dtype("fp8") == jnp.float8_e4m3fn
|
|
234
|
+
assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.float8_e4m3
|
|
235
|
+
assert get_jax_dtype_from_str_dtype("fp8_e5m2") == jnp.float8_e5m2
|
|
236
|
+
assert get_jax_dtype_from_str_dtype("auto") is None
|
tpu_inference/__init__.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
# The environment variables override should be imported before any other
|
|
4
|
+
# modules to ensure that the environment variables are set before any
|
|
5
|
+
# other modules are imported.
|
|
6
|
+
import tpu_inference.env_override # noqa: F401
|
|
7
|
+
from tpu_inference import tpu_info as ti
|
|
8
|
+
from tpu_inference.logger import init_logger
|
|
9
|
+
|
|
10
|
+
logger = init_logger(__name__)
|
|
11
|
+
|
|
12
|
+
if "proxy" in os.environ.get('JAX_PLATFORMS', '').lower():
|
|
13
|
+
logger.info("Running vLLM on TPU via Pathways proxy.")
|
|
14
|
+
# Must run pathwaysutils.initialize() before any JAX operations
|
|
15
|
+
try:
|
|
16
|
+
import pathwaysutils
|
|
17
|
+
pathwaysutils.initialize()
|
|
18
|
+
logger.info("Module pathwaysutils is imported.")
|
|
19
|
+
except Exception as e:
|
|
20
|
+
logger.error(
|
|
21
|
+
f"Error occurred while importing pathwaysutils or logging TPU info: {e}"
|
|
22
|
+
)
|
|
23
|
+
else:
|
|
24
|
+
# Either running on TPU or CPU
|
|
25
|
+
try:
|
|
26
|
+
logger.info(f"TPU info: node_name={ti.get_node_name()} | "
|
|
27
|
+
f"tpu_type={ti.get_tpu_type()} | "
|
|
28
|
+
f"worker_id={ti.get_node_worker_id()} | "
|
|
29
|
+
f"num_chips={ti.get_num_chips()} | "
|
|
30
|
+
f"num_cores_per_chip={ti.get_num_cores_per_chip()}")
|
|
31
|
+
except Exception as e:
|
|
32
|
+
logger.error(
|
|
33
|
+
f"Error occurred while logging TPU info: {e}. Are you running on CPU?"
|
|
34
|
+
)
|
tpu_inference/core/__init__.py
ADDED
|
File without changes
|