tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic; see the package registry page for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +45 -15
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +41 -16
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,414 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from unittest.mock import MagicMock, patch
|
|
16
|
+
|
|
17
|
+
import pytest
|
|
18
|
+
from vllm.config import ModelConfig
|
|
19
|
+
from vllm.lora.request import LoRARequest
|
|
20
|
+
from vllm.v1.kv_cache_interface import KVCacheConfig
|
|
21
|
+
from vllm.v1.outputs import DraftTokenIds
|
|
22
|
+
|
|
23
|
+
# The class we are testing
|
|
24
|
+
from tpu_inference.worker.tpu_worker import TPUWorker
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@pytest.fixture
def mock_vllm_config():
    """Build a hand-assembled stand-in for a VllmConfig.

    The top-level object is a plain ``MagicMock`` (deliberately without a
    ``spec``) so that attribute access never raises ``AttributeError``. The
    nested cache, parallel and sharding configs are also plain mocks, with
    only the attributes the worker tests depend on set explicitly. The model
    config is a real ``ModelConfig`` for ``Qwen/Qwen3-0.6B``.
    """
    # Cache settings: 0.9 memory utilization, no pre-sized block pools.
    cache_cfg = MagicMock()
    cache_cfg.gpu_memory_utilization = 0.9
    cache_cfg.num_gpu_blocks = 0
    cache_cfg.num_cpu_blocks = 0

    # Parallelism: tensor-parallel size 2 on one node; no DP or PP.
    parallel_cfg = MagicMock()
    for attr_name, attr_value in (
        ("tensor_parallel_size", 2),
        ("data_parallel_size", 1),
        ("pipeline_parallel_size", 1),
        ("nnodes", 1),
        ("nnodes_within_dp", 1),
    ):
        setattr(parallel_cfg, attr_name, attr_value)

    sharding_cfg = MagicMock()
    sharding_cfg.total_devices = 2

    vllm_cfg = MagicMock()
    vllm_cfg.model_config = ModelConfig(model="Qwen/Qwen3-0.6B")
    vllm_cfg.cache_config = cache_cfg
    vllm_cfg.parallel_config = parallel_cfg
    vllm_cfg.additional_config = {}
    vllm_cfg.sharding_config = sharding_cfg
    return vllm_cfg
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class TestTPUWorker:
    """Test suite for the TPUWorker class.

    Every test builds a real ``TPUWorker`` around the ``mock_vllm_config``
    fixture. Collaborators (the model runner, ``jax``, ``utils``,
    ``vllm_envs``) are replaced with mocks via ``unittest.mock.patch`` or by
    direct attribute assignment.
    """

    @staticmethod
    def _make_worker(vllm_config, **overrides):
        """Construct a TPUWorker with rank-0 defaults.

        Only the keyword arguments listed here (updated by ``overrides``) are
        forwarded to the constructor. Parameters a test does not mention,
        e.g. ``devices`` or ``is_driver_worker``, keep the constructor's own
        defaults.
        """
        kwargs = {
            "vllm_config": vllm_config,
            "local_rank": 0,
            "rank": 0,
            "distributed_init_method": "test",
        }
        kwargs.update(overrides)
        return TPUWorker(**kwargs)

    #
    # --- Initialization Tests ---
    #

    def test_init_success(self, mock_vllm_config):
        """Tests successful initialization of TPUWorker."""
        worker = self._make_worker(mock_vllm_config,
                                   distributed_init_method="test_method",
                                   is_driver_worker=True,
                                   devices=['tpu:0'])
        assert worker.vllm_config == mock_vllm_config
        assert worker.rank == 0
        assert worker.local_rank == 0
        assert worker.is_driver_worker
        assert worker.profile_dir is None
        assert worker.devices == ['tpu:0']

    @patch('tpu_inference.worker.tpu_worker.vllm_envs')
    def test_init_with_profiler_on_rank_zero(self, mock_envs,
                                             mock_vllm_config):
        """Tests that the profiler directory is set correctly on rank 0."""
        mock_envs.VLLM_TORCH_PROFILER_DIR = "/tmp/profiles"
        worker = self._make_worker(mock_vllm_config,
                                   distributed_init_method="test_method")
        assert worker.profile_dir == "/tmp/profiles"

    @patch('tpu_inference.worker.tpu_worker.vllm_envs')
    def test_init_with_profiler_on_other_ranks(self, mock_envs,
                                               mock_vllm_config):
        """Tests that the profiler directory is NOT set on non-rank 0 workers."""
        mock_envs.VLLM_TORCH_PROFILER_DIR = "/tmp/profiles"
        worker = self._make_worker(mock_vllm_config,
                                   local_rank=1,
                                   rank=1,
                                   distributed_init_method="test_method")
        assert worker.profile_dir is None

    #
    # --- Device and Cache Initialization Tests ---
    #

    def test_initialize_cache(self, mock_vllm_config):
        """Tests setting the number of GPU and CPU cache blocks."""
        worker = self._make_worker(mock_vllm_config,
                                   distributed_init_method="test_method")
        worker.initialize_cache(num_gpu_blocks=2048, num_cpu_blocks=1024)
        assert worker.cache_config.num_gpu_blocks == 2048
        assert worker.cache_config.num_cpu_blocks == 1024

    @patch('tpu_inference.worker.tpu_worker.TPUModelRunner')
    @patch('tpu_inference.worker.tpu_worker.utils')
    @patch('tpu_inference.worker.tpu_worker.jax')
    @patch('tpu_inference.worker.tpu_worker.ensure_kv_transfer_initialized')
    def test_init_device_with_provided_devices(
            self, mock_ensure_kv_transfer_initialized, mock_jax, mock_utils,
            mock_runner_cls, mock_vllm_config):
        """Tests init_device when devices are provided during construction."""
        mock_devices = ['tpu:0', 'tpu:1']
        worker = self._make_worker(mock_vllm_config,
                                   distributed_init_method="test_method",
                                   devices=mock_devices)

        worker.init_device()

        expected_rank = 0
        expected_is_first_rank = True
        expected_is_last_rank = True
        mock_runner_cls.assert_called_once_with(mock_vllm_config, mock_devices,
                                                expected_rank,
                                                expected_is_first_rank,
                                                expected_is_last_rank)
        assert isinstance(worker.model_runner, MagicMock)

    @patch('tpu_inference.worker.tpu_worker.TPUModelRunner')
    @patch('tpu_inference.worker.tpu_worker.utils')
    @patch('tpu_inference.worker.tpu_worker.jax')
    @patch('tpu_inference.worker.tpu_worker.ensure_kv_transfer_initialized')
    def test_init_device_autodetects_devices(
            self, mock_ensure_kv_transfer_initialized, mock_jax, mock_utils,
            mock_runner_cls, mock_vllm_config):
        """Tests init_device when devices are auto-detected via JAX."""
        worker = self._make_worker(
            mock_vllm_config,
            distributed_init_method="test_method",
            devices=[]  # No devices provided, should trigger auto-detection
        )
        mock_jax.device_count.return_value = 4
        mock_jax.devices.return_value = ['tpu:0', 'tpu:1', 'tpu:2', 'tpu:3']

        worker.init_device()

        expected_devices = ['tpu:0', 'tpu:1']  # Sliced by tensor_parallel_size
        assert worker.devices == expected_devices
        expected_rank = 0
        expected_is_first_rank = True
        expected_is_last_rank = True
        mock_runner_cls.assert_called_once_with(mock_vllm_config,
                                                expected_devices,
                                                expected_rank,
                                                expected_is_first_rank,
                                                expected_is_last_rank)

    @patch('tpu_inference.worker.tpu_worker.utils')
    def test_determine_available_memory(self, mock_utils, mock_vllm_config):
        """Tests the available HBM memory calculation."""
        # Setup mock return for hbm_usage_bytes: [(used_bytes, limit_bytes), ...]
        mock_utils.hbm_usage_bytes.return_value = [
            (100 * 1024**3, 1000 * 1024**3), (200 * 1024**3, 1000 * 1024**3)
        ]
        mock_devices = ['tpu:0', 'tpu:1']
        worker = self._make_worker(mock_vllm_config,
                                   distributed_init_method="test_method",
                                   devices=mock_devices)

        available_mem = worker.determine_available_memory()

        mock_utils.hbm_usage_bytes.assert_called_once_with(mock_devices)
        # Total limit: 1000 + 1000 = 2000 GiB
        # Total cap: 2000 * 0.9 = 1800 GiB
        # Total used: 100 + 200 = 300 GiB
        # Total free = 1800 - 300 = 1500 GiB
        expected_mem = 1500 * 1024**3
        assert available_mem == expected_mem

    #
    # --- Core Logic Tests ---
    #

    @patch('tpu_inference.worker.tpu_worker.TPUModelRunner')
    def test_execute_model(self, mock_runner_cls, mock_vllm_config):
        """Tests that the driver worker executes the model and returns the concrete vLLM output."""
        worker = self._make_worker(mock_vllm_config, is_driver_worker=True)
        worker.model_runner = mock_runner_cls.return_value  # Assign mocked runner instance
        mock_scheduler_input = MagicMock()

        # The model runner returns a concrete vllm output
        mock_model_output = "concrete_model_output"
        worker.model_runner.execute_model.return_value = mock_model_output

        result = worker.execute_model(mock_scheduler_input)

        # The runner receives the scheduler output directly; with a single
        # pipeline stage there are no intermediate tensors (None).
        worker.model_runner.execute_model.assert_called_once_with(
            mock_scheduler_input, None)
        # Assert the final result is the concrete model output
        assert result == mock_model_output

    @patch('tpu_inference.worker.tpu_worker.TPUModelRunner')
    def test_execute_model_non_driver_returns_none(self, mock_runner_cls,
                                                   mock_vllm_config):
        """Tests that a non-driver worker executes the model but returns None."""
        worker = self._make_worker(
            mock_vllm_config,
            is_driver_worker=False  # Not a driver
        )
        worker.model_runner = mock_runner_cls.return_value
        mock_scheduler_input = MagicMock()

        result = worker.execute_model(mock_scheduler_input)

        # The docstring promises the model is still executed; verify it with
        # the same call signature the driver-worker test pins down.
        worker.model_runner.execute_model.assert_called_once_with(
            mock_scheduler_input, None)
        assert result is None

    def test_take_draft_token_ids(self, mock_vllm_config):
        """Tests that take_draft_token_ids correctly passes through from the runner."""
        worker = self._make_worker(mock_vllm_config)
        worker.model_runner = MagicMock()

        mock_draft_tokens = DraftTokenIds(req_ids=["req1"],
                                          draft_token_ids=[[1, 2]])
        worker.model_runner.take_draft_token_ids.return_value = mock_draft_tokens

        result = worker.take_draft_token_ids()
        worker.model_runner.take_draft_token_ids.assert_called_once()
        assert result == mock_draft_tokens

    def test_add_lora_not_implemented(self, mock_vllm_config):
        """Tests that add_lora raises NotImplementedError for an arbitrary request object."""
        worker = self._make_worker(mock_vllm_config)
        mock_lora_request = MagicMock()

        with pytest.raises(
                NotImplementedError,
                match="LoRA is not supported by the JAX worker yet."):
            worker.add_lora(mock_lora_request)

    def test_add_lora_not_implemented_lora_request(self, mock_vllm_config):
        """Tests that add_lora raises NotImplementedError for a LoRARequest-spec'd request."""
        worker = self._make_worker(mock_vllm_config)
        mock_lora_request = MagicMock(spec=LoRARequest)

        with pytest.raises(
                NotImplementedError,
                match="LoRA is not supported by the JAX worker yet."):
            worker.add_lora(mock_lora_request)

    #
    # --- Profiling and Health Check Tests ---
    #

    @patch('tpu_inference.worker.tpu_worker.jax')
    @patch.dict('os.environ', {"PYTHON_TRACER_LEVEL": "1"}, clear=True)
    def test_profile_start(self, mock_jax, mock_vllm_config):
        """Tests starting the JAX profiler."""
        worker = self._make_worker(mock_vllm_config)
        worker.profile_dir = "/tmp/profile_dir"

        worker.profile(is_start=True)

        mock_jax.profiler.ProfileOptions.assert_called_once()
        mock_jax.profiler.start_trace.assert_called_once()
        args, kwargs = mock_jax.profiler.start_trace.call_args
        assert args[0] == "/tmp/profile_dir"
        # Verify options from env var were used
        assert kwargs['profiler_options'].python_tracer_level == 1

    @patch('tpu_inference.worker.tpu_worker.jax')
    def test_profile_stop(self, mock_jax, mock_vllm_config):
        """Tests stopping the JAX profiler."""
        worker = self._make_worker(mock_vllm_config)
        worker.profile(is_start=False)
        mock_jax.profiler.stop_trace.assert_called_once()

    def test_check_health(self, mock_vllm_config):
        """Tests that check_health runs without error."""
        worker = self._make_worker(mock_vllm_config)
        # pytest already reports any exception raised here as a test failure,
        # so no explicit try/except wrapper is needed.
        worker.check_health()

    #
    # --- Pass-through Method Tests ---
    #

    @pytest.mark.parametrize(
        "worker_method_name, runner_method_name, method_args", [
            ("load_model", "load_model", []),
            ("get_model", "get_model", []),
            ("get_kv_cache_spec", "get_kv_cache_spec", []),
        ])
    def test_runner_passthrough_methods(self, worker_method_name,
                                        runner_method_name, method_args,
                                        mock_vllm_config):
        """Tests methods that are simple pass-throughs to the TPUModelRunner."""
        worker = self._make_worker(mock_vllm_config)
        worker.model_runner = MagicMock()

        # Call the worker method and assert the underlying runner method was called
        getattr(worker, worker_method_name)(*method_args)
        mock_runner_method = getattr(worker.model_runner, runner_method_name)
        mock_runner_method.assert_called_once_with(*method_args)

    def test_initialize_from_config(self, mock_vllm_config):
        """Tests initialize_from_config forwards an arbitrary config and the topology order id."""
        worker = self._make_worker(mock_vllm_config)
        worker.model_runner = MagicMock()
        worker.topology_order_id = 0
        mock_input_config = MagicMock()

        worker.initialize_from_config(mock_input_config)

        worker.model_runner.initialize_kv_cache.assert_called_once_with(
            mock_input_config, 0)

    def test_initialize_from_config_kv_cache_config(self, mock_vllm_config):
        """Tests initialize_from_config forwards a KVCacheConfig-spec'd config and the topology order id."""
        worker = self._make_worker(mock_vllm_config)
        worker.model_runner = MagicMock()
        worker.topology_order_id = 0
        mock_input_config = MagicMock(spec=KVCacheConfig)

        worker.initialize_from_config(mock_input_config)

        worker.model_runner.initialize_kv_cache.assert_called_once_with(
            mock_input_config, 0)

    def test_compile_or_warm_up_model(self, mock_vllm_config):
        """Tests the special case pass-through for model compilation/warmup."""
        worker = self._make_worker(mock_vllm_config)
        worker.model_runner = MagicMock()

        worker.compile_or_warm_up_model()

        # This method calls two different runner methods
        worker.model_runner.capture_model.assert_called_once()
        worker.model_runner._init_random.assert_called_once()

    def test_get_supported_tasks(self, mock_vllm_config):
        """Test get_supported_tasks passthrough to model runner."""
        worker = self._make_worker(mock_vllm_config)
        worker.model_runner = MagicMock()
        worker.model_runner.get_supported_tasks.return_value = ("generate", )

        supported = worker.get_supported_tasks()

        worker.model_runner.get_supported_tasks.assert_called_once()
        # A pass-through must hand back exactly what the runner reported.
        assert supported == ("generate", )
|
tpu_inference/__init__.py
CHANGED
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
# The environment variables override should be imported before any other
|
|
2
16
|
# modules to ensure that the environment variables are set before any
|
|
3
17
|
# other modules are imported.
|
tpu_inference/core/__init__.py
CHANGED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|