tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.2rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +89 -26
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +46 -17
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +44 -17
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/METADATA +7 -9
- tpu_inference-0.13.2rc3.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/top_level.txt +0 -0
tests/core/test_init.py
CHANGED
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
import importlib
|
|
2
16
|
import unittest
|
|
3
17
|
from unittest.mock import patch
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from collections import namedtuple
|
|
16
|
+
|
|
17
|
+
import pytest
|
|
18
|
+
|
|
19
|
+
from tpu_inference.distributed.utils import get_device_topology_order_id
|
|
20
|
+
|
|
21
|
+
# Mock TpuDevice object to simulate the real one.
|
|
22
|
+
TpuDevice = namedtuple('TpuDevice',
|
|
23
|
+
['id', 'process_index', 'coords', 'core_on_chip'])
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def test_get_device_topology_order_id():
|
|
27
|
+
"""
|
|
28
|
+
Tests the get_device_topology_order_id function with a mock topology.
|
|
29
|
+
"""
|
|
30
|
+
# V7x
|
|
31
|
+
global_devices = [
|
|
32
|
+
TpuDevice(id=0, process_index=0, coords=(0, 0, 0), core_on_chip=0),
|
|
33
|
+
TpuDevice(id=1, process_index=0, coords=(0, 0, 0), core_on_chip=1),
|
|
34
|
+
TpuDevice(id=2, process_index=0, coords=(1, 0, 0), core_on_chip=0),
|
|
35
|
+
TpuDevice(id=3, process_index=0, coords=(1, 0, 0), core_on_chip=1),
|
|
36
|
+
TpuDevice(id=4, process_index=0, coords=(0, 1, 0), core_on_chip=0),
|
|
37
|
+
TpuDevice(id=5, process_index=0, coords=(0, 1, 0), core_on_chip=1),
|
|
38
|
+
TpuDevice(id=6, process_index=0, coords=(1, 1, 0), core_on_chip=0),
|
|
39
|
+
TpuDevice(id=7, process_index=0, coords=(1, 1, 0), core_on_chip=1),
|
|
40
|
+
TpuDevice(id=8, process_index=1, coords=(0, 0, 1), core_on_chip=0),
|
|
41
|
+
TpuDevice(id=9, process_index=1, coords=(0, 0, 1), core_on_chip=1),
|
|
42
|
+
TpuDevice(id=10, process_index=1, coords=(1, 0, 1), core_on_chip=0),
|
|
43
|
+
TpuDevice(id=11, process_index=1, coords=(1, 0, 1), core_on_chip=1),
|
|
44
|
+
TpuDevice(id=12, process_index=1, coords=(0, 1, 1), core_on_chip=0),
|
|
45
|
+
TpuDevice(id=13, process_index=1, coords=(0, 1, 1), core_on_chip=1),
|
|
46
|
+
TpuDevice(id=14, process_index=1, coords=(1, 1, 1), core_on_chip=0),
|
|
47
|
+
TpuDevice(id=15, process_index=1, coords=(1, 1, 1), core_on_chip=1),
|
|
48
|
+
]
|
|
49
|
+
|
|
50
|
+
local_devices_1 = global_devices[:8]
|
|
51
|
+
local_devices_2 = global_devices[8:]
|
|
52
|
+
|
|
53
|
+
assert get_device_topology_order_id(local_devices_1, global_devices) == 0
|
|
54
|
+
assert get_device_topology_order_id(local_devices_2, global_devices) == 1
|
|
55
|
+
|
|
56
|
+
# Test with unsorted in global_devices
|
|
57
|
+
shuffled_z_global_devices = [
|
|
58
|
+
TpuDevice(id=8, process_index=1, coords=(0, 0, 1), core_on_chip=0),
|
|
59
|
+
TpuDevice(id=0, process_index=0, coords=(0, 0, 0), core_on_chip=0),
|
|
60
|
+
]
|
|
61
|
+
local_devices_z1 = [
|
|
62
|
+
TpuDevice(id=8, process_index=1, coords=(0, 0, 1), core_on_chip=0)
|
|
63
|
+
]
|
|
64
|
+
local_devices_z0 = [
|
|
65
|
+
TpuDevice(id=0, process_index=0, coords=(0, 0, 0), core_on_chip=0)
|
|
66
|
+
]
|
|
67
|
+
|
|
68
|
+
assert get_device_topology_order_id(local_devices_z0,
|
|
69
|
+
shuffled_z_global_devices) == 0
|
|
70
|
+
assert get_device_topology_order_id(local_devices_z1,
|
|
71
|
+
shuffled_z_global_devices) == 1
|
|
72
|
+
|
|
73
|
+
#v6e
|
|
74
|
+
global_devices = [
|
|
75
|
+
TpuDevice(id=0, process_index=0, coords=(0, 0, 0), core_on_chip=0),
|
|
76
|
+
TpuDevice(id=1, process_index=1, coords=(1, 0, 0), core_on_chip=0),
|
|
77
|
+
TpuDevice(id=2, process_index=2, coords=(0, 1, 0), core_on_chip=0),
|
|
78
|
+
TpuDevice(id=3, process_index=3, coords=(1, 1, 0), core_on_chip=0)
|
|
79
|
+
]
|
|
80
|
+
local_devices = [
|
|
81
|
+
TpuDevice(id=0, process_index=0, coords=(0, 0, 0), core_on_chip=0)
|
|
82
|
+
]
|
|
83
|
+
assert get_device_topology_order_id(local_devices, global_devices) == 0
|
|
84
|
+
|
|
85
|
+
local_devices = [
|
|
86
|
+
TpuDevice(id=1, process_index=1, coords=(1, 0, 0), core_on_chip=0)
|
|
87
|
+
]
|
|
88
|
+
assert get_device_topology_order_id(local_devices, global_devices) == 2
|
|
89
|
+
|
|
90
|
+
local_devices = [
|
|
91
|
+
TpuDevice(id=2, process_index=2, coords=(0, 1, 0), core_on_chip=0)
|
|
92
|
+
]
|
|
93
|
+
assert get_device_topology_order_id(local_devices, global_devices) == 1
|
|
94
|
+
|
|
95
|
+
local_devices = [
|
|
96
|
+
TpuDevice(id=3, process_index=3, coords=(1, 1, 0), core_on_chip=0)
|
|
97
|
+
]
|
|
98
|
+
assert get_device_topology_order_id(local_devices, global_devices) == 3
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def test_get_device_topology_order_id_empty_local():
|
|
102
|
+
"""
|
|
103
|
+
Tests that a ValueError is raised for empty local_devices.
|
|
104
|
+
"""
|
|
105
|
+
with pytest.raises(ValueError, match="local_devices cannot be empty"):
|
|
106
|
+
get_device_topology_order_id([], [])
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def test_get_device_topology_order_id_not_in_global():
|
|
110
|
+
"""
|
|
111
|
+
Tests that a ValueError is raised if local z-coordinate is not in global list.
|
|
112
|
+
"""
|
|
113
|
+
global_devices = [
|
|
114
|
+
TpuDevice(id=0, process_index=0, coords=(0, 0, 0), core_on_chip=0),
|
|
115
|
+
]
|
|
116
|
+
local_devices = [
|
|
117
|
+
TpuDevice(id=1, process_index=1, coords=(0, 0, 1), core_on_chip=0),
|
|
118
|
+
]
|
|
119
|
+
with pytest.raises(ValueError, match="do not exist in the global device:"):
|
|
120
|
+
get_device_topology_order_id(local_devices, global_devices)
|
|
@@ -0,0 +1,478 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import unittest
|
|
16
|
+
from unittest.mock import MagicMock, patch
|
|
17
|
+
|
|
18
|
+
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole
|
|
19
|
+
from vllm.v1.request import RequestStatus
|
|
20
|
+
|
|
21
|
+
from tpu_inference.distributed import tpu_connector
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class MockVllmConfig:
|
|
25
|
+
|
|
26
|
+
def __init__(self):
|
|
27
|
+
self.kv_transfer_config = MagicMock()
|
|
28
|
+
self.kv_transfer_config.is_kv_producer = True
|
|
29
|
+
self.cache_config = MagicMock()
|
|
30
|
+
self.cache_config.block_size = 16
|
|
31
|
+
self.parallel_config = MagicMock()
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@patch("tpu_inference.distributed.tpu_connector.TPUConnectorWorker")
|
|
35
|
+
@patch("tpu_inference.distributed.tpu_connector.TPUConnectorScheduler")
|
|
36
|
+
class TestTPUConnector(unittest.TestCase):
|
|
37
|
+
|
|
38
|
+
def setUp(self):
|
|
39
|
+
self.vllm_config = MockVllmConfig()
|
|
40
|
+
|
|
41
|
+
def test_init_scheduler_role(self, mock_scheduler_cls, mock_worker_cls):
|
|
42
|
+
"""
|
|
43
|
+
Tests that TPUConnector initializes the scheduler connector for the
|
|
44
|
+
SCHEDULER role.
|
|
45
|
+
"""
|
|
46
|
+
connector = tpu_connector.TPUConnector(self.vllm_config,
|
|
47
|
+
KVConnectorRole.SCHEDULER)
|
|
48
|
+
mock_scheduler_cls.assert_called_once_with(self.vllm_config)
|
|
49
|
+
mock_worker_cls.assert_not_called()
|
|
50
|
+
self.assertIsNotNone(connector.connector_scheduler)
|
|
51
|
+
self.assertIsNone(connector.connector_worker)
|
|
52
|
+
|
|
53
|
+
def test_init_worker_role(self, mock_scheduler_cls, mock_worker_cls):
|
|
54
|
+
"""
|
|
55
|
+
Tests that TPUConnector initializes the worker connector for the WORKER
|
|
56
|
+
role.
|
|
57
|
+
"""
|
|
58
|
+
connector = tpu_connector.TPUConnector(self.vllm_config,
|
|
59
|
+
KVConnectorRole.WORKER)
|
|
60
|
+
mock_worker_cls.assert_called_once_with(self.vllm_config)
|
|
61
|
+
mock_scheduler_cls.assert_not_called()
|
|
62
|
+
self.assertIsNone(connector.connector_scheduler)
|
|
63
|
+
self.assertIsNotNone(connector.connector_worker)
|
|
64
|
+
|
|
65
|
+
def test_scheduler_methods_are_called(self, mock_scheduler_cls,
|
|
66
|
+
mock_worker_cls):
|
|
67
|
+
"""Tests that scheduler-side methods are correctly delegated."""
|
|
68
|
+
mock_scheduler_instance = mock_scheduler_cls.return_value
|
|
69
|
+
connector = tpu_connector.TPUConnector(self.vllm_config,
|
|
70
|
+
KVConnectorRole.SCHEDULER)
|
|
71
|
+
|
|
72
|
+
mock_request = MagicMock()
|
|
73
|
+
mock_blocks = MagicMock()
|
|
74
|
+
mock_scheduler_output = MagicMock()
|
|
75
|
+
|
|
76
|
+
connector.get_num_new_matched_tokens(mock_request, 16)
|
|
77
|
+
mock_scheduler_instance.get_num_new_matched_tokens.assert_called_once_with(
|
|
78
|
+
mock_request, 16)
|
|
79
|
+
|
|
80
|
+
connector.update_state_after_alloc(mock_request, mock_blocks, 16)
|
|
81
|
+
mock_scheduler_instance.update_state_after_alloc.assert_called_once_with(
|
|
82
|
+
mock_request, mock_blocks, 16)
|
|
83
|
+
|
|
84
|
+
connector.build_connector_meta(mock_scheduler_output)
|
|
85
|
+
mock_scheduler_instance.build_connector_meta.assert_called_once_with()
|
|
86
|
+
|
|
87
|
+
connector.request_finished(mock_request, [1, 2])
|
|
88
|
+
mock_scheduler_instance.request_finished.assert_called_once_with(
|
|
89
|
+
mock_request, [1, 2])
|
|
90
|
+
|
|
91
|
+
def test_worker_methods_are_called(self, mock_scheduler_cls,
|
|
92
|
+
mock_worker_cls):
|
|
93
|
+
"""Tests that worker-side methods are correctly delegated."""
|
|
94
|
+
mock_worker_instance = mock_worker_cls.return_value
|
|
95
|
+
connector = tpu_connector.TPUConnector(self.vllm_config,
|
|
96
|
+
KVConnectorRole.WORKER)
|
|
97
|
+
connector._connector_metadata = tpu_connector.TPUConnectorMetadata(
|
|
98
|
+
) # need to set this for start_load_kv
|
|
99
|
+
|
|
100
|
+
mock_runner = MagicMock()
|
|
101
|
+
|
|
102
|
+
connector.register_runner(mock_runner)
|
|
103
|
+
mock_worker_instance.register_runner.assert_called_once_with(
|
|
104
|
+
mock_runner)
|
|
105
|
+
|
|
106
|
+
connector.start_load_kv(None)
|
|
107
|
+
mock_worker_instance.process_send_load.assert_called_once_with(
|
|
108
|
+
connector._connector_metadata)
|
|
109
|
+
|
|
110
|
+
connector.get_finished(set())
|
|
111
|
+
mock_worker_instance.get_finished.assert_called_once_with()
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class TestTPUConnectorScheduler(unittest.TestCase):
|
|
115
|
+
|
|
116
|
+
def setUp(self):
|
|
117
|
+
self.vllm_config = MockVllmConfig()
|
|
118
|
+
self.vllm_config.cache_config.block_size = 16
|
|
119
|
+
self.vllm_config.kv_transfer_config.is_kv_producer = False
|
|
120
|
+
|
|
121
|
+
with patch("tpu_inference.distributed.tpu_connector.get_kv_ips",
|
|
122
|
+
return_value="1.1.1.1"), patch(
|
|
123
|
+
"tpu_inference.distributed.tpu_connector.get_kv_ports",
|
|
124
|
+
return_value=12345):
|
|
125
|
+
self.scheduler = tpu_connector.TPUConnectorScheduler(
|
|
126
|
+
self.vllm_config)
|
|
127
|
+
|
|
128
|
+
def test_get_num_new_matched_tokens_producer(self):
|
|
129
|
+
"""Tests that producer returns 0 tokens to load."""
|
|
130
|
+
self.scheduler.is_producer = True
|
|
131
|
+
mock_request = MagicMock()
|
|
132
|
+
num_tokens, is_async = self.scheduler.get_num_new_matched_tokens(
|
|
133
|
+
mock_request, 16)
|
|
134
|
+
self.assertEqual(num_tokens, 0)
|
|
135
|
+
self.assertFalse(is_async)
|
|
136
|
+
|
|
137
|
+
def test_get_num_new_matched_tokens_consumer_needs_loading(self):
|
|
138
|
+
"""Tests consumer calculates correct number of tokens to load."""
|
|
139
|
+
mock_request = MagicMock()
|
|
140
|
+
mock_request.prompt_token_ids = [0] * 35 # 2 blocks worth, plus some
|
|
141
|
+
num_computed_tokens = 16 # 1 block
|
|
142
|
+
# rounded_down(35) = 32. 32 - 16 = 16.
|
|
143
|
+
expected_tokens = 16
|
|
144
|
+
num_tokens, is_async = self.scheduler.get_num_new_matched_tokens(
|
|
145
|
+
mock_request, num_computed_tokens)
|
|
146
|
+
self.assertEqual(num_tokens, expected_tokens)
|
|
147
|
+
self.assertTrue(is_async)
|
|
148
|
+
|
|
149
|
+
def test_get_num_new_matched_tokens_consumer_no_loading(self):
|
|
150
|
+
"""Tests consumer returns 0 if prompt is fully cached."""
|
|
151
|
+
mock_request = MagicMock()
|
|
152
|
+
mock_request.prompt_token_ids = [0] * 31 # less than 2 blocks
|
|
153
|
+
num_computed_tokens = 32 # 2 blocks computed
|
|
154
|
+
expected_tokens = 0
|
|
155
|
+
num_tokens, is_async = self.scheduler.get_num_new_matched_tokens(
|
|
156
|
+
mock_request, num_computed_tokens)
|
|
157
|
+
self.assertEqual(num_tokens, expected_tokens)
|
|
158
|
+
self.assertFalse(is_async)
|
|
159
|
+
|
|
160
|
+
def test_update_state_after_alloc_producer(self):
|
|
161
|
+
"""Tests that update_state_after_alloc is a no-op for producers."""
|
|
162
|
+
self.scheduler.is_producer = True
|
|
163
|
+
self.scheduler.update_state_after_alloc(MagicMock(), MagicMock(), 16)
|
|
164
|
+
self.assertEqual(len(self.scheduler.reqs_to_load), 0)
|
|
165
|
+
|
|
166
|
+
def test_update_state_after_alloc_consumer_with_external_tokens(self):
|
|
167
|
+
"""
|
|
168
|
+
Tests consumer state is updated when external tokens are needed.
|
|
169
|
+
"""
|
|
170
|
+
mock_request = MagicMock()
|
|
171
|
+
mock_request.request_id = "req1"
|
|
172
|
+
mock_request.kv_transfer_params = {
|
|
173
|
+
"uuid": 123,
|
|
174
|
+
"remote_block_ids": [10, 11],
|
|
175
|
+
"remote_host": "2.2.2.2",
|
|
176
|
+
"remote_port": 54321
|
|
177
|
+
}
|
|
178
|
+
mock_blocks = MagicMock()
|
|
179
|
+
mock_blocks.get_block_ids.return_value = [[1, 2]]
|
|
180
|
+
num_external_tokens = 32
|
|
181
|
+
|
|
182
|
+
self.scheduler.update_state_after_alloc(mock_request, mock_blocks,
|
|
183
|
+
num_external_tokens)
|
|
184
|
+
|
|
185
|
+
self.assertIn("req1", self.scheduler.reqs_to_load)
|
|
186
|
+
load_meta = self.scheduler.reqs_to_load["req1"]
|
|
187
|
+
self.assertEqual(load_meta.uuid, 123)
|
|
188
|
+
self.assertEqual(load_meta.local_block_ids, [1, 2])
|
|
189
|
+
self.assertEqual(load_meta.remote_block_ids, [10, 11])
|
|
190
|
+
|
|
191
|
+
def test_update_state_after_alloc_consumer_no_external_tokens(self):
|
|
192
|
+
"""
|
|
193
|
+
Tests consumer state is updated for notification when no external
|
|
194
|
+
tokens are needed.
|
|
195
|
+
"""
|
|
196
|
+
mock_request = MagicMock()
|
|
197
|
+
mock_request.request_id = "req1"
|
|
198
|
+
mock_request.kv_transfer_params = {
|
|
199
|
+
"uuid": 123,
|
|
200
|
+
"remote_block_ids": [10, 11],
|
|
201
|
+
"remote_host": "2.2.2.2",
|
|
202
|
+
"remote_port": 54321
|
|
203
|
+
}
|
|
204
|
+
mock_blocks = MagicMock()
|
|
205
|
+
num_external_tokens = 0
|
|
206
|
+
|
|
207
|
+
self.scheduler.update_state_after_alloc(mock_request, mock_blocks,
|
|
208
|
+
num_external_tokens)
|
|
209
|
+
|
|
210
|
+
self.assertIn("req1", self.scheduler.reqs_to_load)
|
|
211
|
+
load_meta = self.scheduler.reqs_to_load["req1"]
|
|
212
|
+
self.assertEqual(load_meta.uuid, 123)
|
|
213
|
+
self.assertIsNone(load_meta.local_block_ids)
|
|
214
|
+
self.assertIsNone(load_meta.remote_block_ids)
|
|
215
|
+
|
|
216
|
+
def test_build_connector_meta(self):
|
|
217
|
+
"""Tests that metadata is built and state is cleared."""
|
|
218
|
+
self.scheduler.is_producer = True
|
|
219
|
+
self.scheduler.reqs_to_send = {"req1": "meta1"}
|
|
220
|
+
meta = self.scheduler.build_connector_meta()
|
|
221
|
+
self.assertEqual(meta.reqs_to_send, {"req1": "meta1"})
|
|
222
|
+
self.assertEqual(len(self.scheduler.reqs_to_send),
|
|
223
|
+
0) # check it was cleared
|
|
224
|
+
|
|
225
|
+
self.scheduler.is_producer = False
|
|
226
|
+
self.scheduler.reqs_to_load = {"req2": "meta2"}
|
|
227
|
+
meta = self.scheduler.build_connector_meta()
|
|
228
|
+
self.assertEqual(meta.reqs_to_load, {"req2": "meta2"})
|
|
229
|
+
self.assertEqual(len(self.scheduler.reqs_to_load), 0)
|
|
230
|
+
|
|
231
|
+
def test_request_finished_consumer(self):
|
|
232
|
+
"""Tests request_finished is a no-op for consumers."""
|
|
233
|
+
self.scheduler.is_producer = False
|
|
234
|
+
delay_free, params = self.scheduler.request_finished(MagicMock(), [])
|
|
235
|
+
self.assertFalse(delay_free)
|
|
236
|
+
self.assertIsNone(params)
|
|
237
|
+
|
|
238
|
+
@patch("tpu_inference.distributed.tpu_connector.get_uuid",
       return_value=456)
def test_request_finished_producer_finished_by_length(self, mock_get_uuid):
    """A producer queues a length-capped request for sending.

    ``get_uuid`` is patched to return 456 so the transfer id is predictable.
    """
    self.scheduler.is_producer = True
    finished_request = MagicMock()
    finished_request.request_id = "req-finished"
    finished_request.status = RequestStatus.FINISHED_LENGTH_CAPPED
    finished_request.num_computed_tokens = 32  # 2 blocks
    owned_blocks = [1, 2]

    delay_free, params = self.scheduler.request_finished(
        finished_request, owned_blocks)

    # The producer asks for the blocks' release to be delayed and records
    # a send entry for the request.
    self.assertTrue(delay_free)
    self.assertIn("req-finished", self.scheduler.reqs_to_send)
    queued = self.scheduler.reqs_to_send["req-finished"]
    self.assertEqual(queued.uuid, 456)
    self.assertEqual(queued.local_block_ids, [1, 2])

    # The returned KV-transfer params describe where the blocks can be
    # fetched. Host/port presumably come from the fixture's config (set up
    # outside this method) -- NOTE(review): confirm against setUp.
    self.assertIsNotNone(params)
    expected_params = {
        "uuid": 456,
        "remote_block_ids": [1, 2],
        "remote_host": "1.1.1.1",
        "remote_port": 12345,
    }
    for key, expected in expected_params.items():
        self.assertEqual(params[key], expected)
|
|
263
|
+
|
|
264
|
+
def test_request_finished_producer_not_finished(self):
    """A producer returns (False, None) for a request that is still running."""
    self.scheduler.is_producer = True
    running_request = MagicMock()
    running_request.status = RequestStatus.RUNNING  # Not finished

    result = self.scheduler.request_finished(running_request, [1, 2])
    delay_free, params = result

    self.assertFalse(delay_free)
    self.assertIsNone(params)
|
|
273
|
+
|
|
274
|
+
def test_request_finished_producer_prompt_too_short(self):
    """A producer skips the transfer when fewer tokens than one block exist.

    In that case request_finished returns (False, {}) and nothing is queued
    for sending.
    """
    self.scheduler.is_producer = True
    short_request = MagicMock()
    short_request.request_id = "req-short"
    short_request.status = RequestStatus.FINISHED_LENGTH_CAPPED
    short_request.num_computed_tokens = 10  # less than a block

    delay_free, params = self.scheduler.request_finished(short_request, [1])

    self.assertFalse(delay_free)
    self.assertEqual(params, {})
    self.assertNotIn("req-short", self.scheduler.reqs_to_send)
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
class TestTPUConnectorWorker(unittest.TestCase):
    """Unit tests for ``tpu_connector.TPUConnectorWorker``.

    The worker's module-level collaborators (JAX, ZMQ, threading, the
    transfer server, KV-cache helpers, ``time``, ...) are patched in
    ``setUp``. This lets a worker be constructed without real devices,
    sockets or threads.
    """

    # Module whose globals are patched for every test in this class.
    _MODULE = 'tpu_inference.distributed.tpu_connector'

    def setUp(self):
        self.vllm_config = MockVllmConfig()
        # Attribute name in ``_MODULE`` -> extra keyword arguments for patch.
        patch_targets = {
            "jax": {},
            "get_host_ip": {"return_value": '127.0.0.1'},
            "get_kv_transfer_port": {"return_value": 10000},
            "get_side_channel_port": {"return_value": 20000},
            "start_transfer_server": {},
            "zmq": {},
            "threading": {},
            "ThreadPoolExecutor": {},
            "device_array": {},
            "select_from_kv_caches": {},
            "scatter_kv_slices": {},
            "time": {},
            "make_zmq_path": {},
            "make_zmq_socket": {},
        }
        self.all_mocks = {}
        for name, patch_kwargs in patch_targets.items():
            patcher = patch(f"{self._MODULE}.{name}", **patch_kwargs)
            self.all_mocks[name] = patcher.start()
            # Register the cleanup right away. If a later start() raises,
            # unittest still runs the cleanups registered so far, so no
            # already-applied patch leaks into other tests.
            self.addCleanup(patcher.stop)
        self.all_mocks["jax"].local_devices.return_value = [MagicMock()]

    def _make_worker(self, is_producer):
        """Set the KV-transfer role on the config and build a worker."""
        self.vllm_config.kv_transfer_config.is_kv_producer = is_producer
        return tpu_connector.TPUConnectorWorker(self.vllm_config)

    def test_init_producer(self):
        """Tests worker initialization for the producer role."""
        worker = self._make_worker(is_producer=True)

        self.all_mocks["zmq"].Context.assert_called_once()
        # A producer starts exactly one thread and no thread pool.
        self.all_mocks["threading"].Thread.assert_called_once()
        self.all_mocks["threading"].Event.assert_called()
        self.all_mocks["ThreadPoolExecutor"].assert_not_called()
        self.assertTrue(worker.is_producer)

    def test_init_consumer(self):
        """Tests worker initialization for the consumer role."""
        worker = self._make_worker(is_producer=False)

        self.all_mocks["zmq"].Context.assert_called_once()
        # A consumer starts no thread but creates a 64-worker pool.
        self.all_mocks["threading"].Thread.assert_not_called()
        self.all_mocks["ThreadPoolExecutor"].assert_called_once_with(
            max_workers=64)
        self.assertFalse(worker.is_producer)

    def test_register_runner(self):
        """Tests that runner registration correctly sets worker attributes."""
        worker = self._make_worker(is_producer=False)

        mock_runner = MagicMock()
        mock_kv_cache_layer = MagicMock()
        mock_kv_cache_layer.shape = [10, 20, 30, 40]
        mock_kv_cache_layer.dtype = 'float32'
        mock_kv_cache_layer.sharding = 'sharding_spec'
        mock_runner.kv_caches = [mock_kv_cache_layer] * 5
        mock_runner.mesh = 'mesh'

        worker.register_runner(mock_runner)

        self.all_mocks["start_transfer_server"].assert_called_once()
        self.assertEqual(worker.runner, mock_runner)
        self.assertEqual(worker.mesh, 'mesh')
        # Layer count comes from len(kv_caches); shape/dtype/sharding are
        # taken from the (identical) per-layer caches.
        self.assertEqual(worker.num_layers, 5)
        self.assertEqual(worker.shape, [10, 20, 30, 40])
        self.assertEqual(worker.dtype, 'float32')
        self.assertEqual(worker.sharding, 'sharding_spec')

    def test_process_send_load_for_producer(self):
        """Tests process_send_load for the producer role."""
        worker = self._make_worker(is_producer=True)
        worker._prepare_kv_and_wait = MagicMock()

        meta = tpu_connector.TPUConnectorMetadata()
        send_meta = tpu_connector.SendMeta(uuid=1,
                                           local_block_ids=[1],
                                           expiration_time=123)
        meta.reqs_to_send = {"req1": send_meta}

        worker.process_send_load(meta)

        worker._prepare_kv_and_wait.assert_called_once_with("req1", send_meta)

    def test_process_send_load_for_consumer_loading(self):
        """Tests process_send_load for a consumer that needs to load KV."""
        worker = self._make_worker(is_producer=False)
        worker._maybe_build_kv_connection = MagicMock(return_value="conn")

        meta = tpu_connector.TPUConnectorMetadata()
        load_meta = tpu_connector.LoadMeta(uuid=1,
                                           local_block_ids=[1],
                                           remote_block_ids=[10],
                                           remote_host="host",
                                           remote_port=123)
        meta.reqs_to_load = {"req1": load_meta}

        worker.process_send_load(meta)

        worker._maybe_build_kv_connection.assert_called_once_with(load_meta)
        # The pull itself is handed off to the consumer's thread pool.
        executor = self.all_mocks["ThreadPoolExecutor"].return_value
        executor.submit.assert_called_once_with(worker._pull_kv, "conn",
                                                load_meta)

    def test_process_send_load_for_consumer_notifying(self):
        """Tests process_send_load for a consumer that needs to notify."""
        worker = self._make_worker(is_producer=False)
        worker._maybe_build_notif_socket = MagicMock(return_value="socket")
        worker._notify_pull_done = MagicMock()

        meta = tpu_connector.TPUConnectorMetadata()
        # No block ids: nothing to load, only a pull-done notification.
        load_meta = tpu_connector.LoadMeta(uuid=1,
                                           local_block_ids=None,
                                           remote_block_ids=None,
                                           remote_host="host",
                                           remote_port=123)
        meta.reqs_to_load = {"req1": load_meta}

        worker.process_send_load(meta)

        worker._maybe_build_notif_socket.assert_called_once_with(load_meta)
        worker._notify_pull_done.assert_called_once_with("socket", "req1")

    def test_get_finished_recving(self):
        """Tests get_finished for a request that has finished pulling."""
        worker = self._make_worker(is_producer=False)
        worker.runner = MagicMock()
        # Captured before get_finished, which may rebind runner.kv_caches.
        original_kv_caches = worker.runner.kv_caches

        mock_future = MagicMock()
        mock_future.done.return_value = True
        mock_future.result.return_value = ('kv_data', 'indices')
        worker.reqs_pulling = {'req1': mock_future}

        done_sending, done_recving = worker.get_finished()

        self.assertEqual(done_sending, set())
        self.assertEqual(done_recving, {'req1'})
        self.assertNotIn('req1', worker.reqs_pulling)
        self.all_mocks['scatter_kv_slices'].assert_called_once_with(
            original_kv_caches, 'kv_data', 'indices')

    def test_get_finished_sending_expired(self):
        """Tests get_finished for a request that has expired."""
        worker = self._make_worker(is_producer=True)

        self.all_mocks['time'].perf_counter.return_value = 1000
        # Assumed layout [kv_data, expiration_time]; 900 is earlier than the
        # mocked clock (1000), so the entry is expired. TODO confirm layout.
        worker.reqs_wait_pull = {'req1': ['kv_data', 900]}

        done_sending, done_recving = worker.get_finished()

        self.assertEqual(done_sending, {'req1'})
        self.assertEqual(done_recving, set())
        self.assertNotIn('req1', worker.reqs_wait_pull)
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
if __name__ == "__main__":
    # Let the module be executed directly as a script to run its tests.
    unittest.main()
|
tests/e2e/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|