tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,724 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from unittest.mock import MagicMock, patch
|
|
16
|
+
|
|
17
|
+
import pytest
|
|
18
|
+
from vllm.config import VllmConfig
|
|
19
|
+
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
|
|
20
|
+
from vllm.v1.core.sched.scheduler import Scheduler
|
|
21
|
+
from vllm.v1.kv_cache_interface import KVCacheConfig
|
|
22
|
+
from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
|
|
23
|
+
from vllm.v1.request import Request
|
|
24
|
+
|
|
25
|
+
from tpu_inference.core.sched.dp_scheduler import (
|
|
26
|
+
DPScheduler, DPSchedulerOutput, SchedulerCommand,
|
|
27
|
+
update_vllm_config_for_dp_scheduler)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class TestDPScheduler:
|
|
31
|
+
|
|
32
|
+
@pytest.fixture
def mock_vllm_config(self):
    """Return a ``VllmConfig``-spec'd mock with sharding and scheduler settings.

    The sharding config reports ``total_dp_size == 2``; the scheduler config
    carries the stock vLLM ``Scheduler`` as ``_original_scheduler_cls`` and
    has async scheduling turned off.
    """
    sharding_cfg = MagicMock()
    sharding_cfg.total_dp_size = 2

    sched_cfg = MagicMock()
    sched_cfg._original_scheduler_cls = Scheduler
    sched_cfg.max_num_seqs = 8
    sched_cfg.max_num_batched_tokens = 1024
    sched_cfg.async_scheduling = False

    vllm_cfg = MagicMock(spec=VllmConfig)
    vllm_cfg.sharding_config = sharding_cfg
    vllm_cfg.scheduler_config = sched_cfg
    return vllm_cfg
|
|
44
|
+
|
|
45
|
+
@pytest.fixture
def mock_kv_cache_config(self):
    """Return a ``KVCacheConfig``-spec'd mock exposing ``num_blocks == 100``."""
    kv_cfg = MagicMock(spec=KVCacheConfig)
    kv_cfg.num_blocks = 100
    return kv_cfg
|
|
51
|
+
|
|
52
|
+
@pytest.fixture
def mock_structured_output_manager(self):
    """Return a plain ``MagicMock`` standing in for the structured output manager."""
    manager = MagicMock()
    return manager
|
|
56
|
+
|
|
57
|
+
def test_init_creates_worker_processes(
    self,
    mock_vllm_config,
    mock_kv_cache_config,
    mock_structured_output_manager,
):
    """Constructing a DPScheduler sets up per-rank processes, queues and configs.

    ``multiprocessing.get_context`` is patched so no real process is
    spawned; the fake context hands back the same mock process/queue on
    every call, which lets the test count ``start()`` invocations.
    """
    worker_target = (
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with patch(worker_target), \
            patch('multiprocessing.get_context') as mock_get_context:
        # Fake multiprocessing context returned by get_context().
        fake_queue = MagicMock()
        fake_process = MagicMock()
        fake_ctx = MagicMock()
        fake_ctx.Queue = MagicMock(return_value=fake_queue)
        fake_ctx.Process = MagicMock(return_value=fake_process)
        mock_get_context.return_value = fake_ctx

        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
            log_stats=True,
        )

        # One process and one queue pair per DP rank (total_dp_size == 2).
        assert scheduler.dp_size == 2
        assert len(scheduler.processes) == 2
        assert len(scheduler.input_queues) == 2
        assert len(scheduler.output_queues) == 2
        assert scheduler.log_stats is True
        assert len(scheduler.per_rank_kv_cache_configs) == 2

        # The 100 blocks of the fixture are split evenly: 100 / 2.
        for per_rank_cfg in scheduler.per_rank_kv_cache_configs:
            assert per_rank_cfg.num_blocks == 50

        # Every created process must have been started.
        assert fake_process.start.call_count == 2
|
|
99
|
+
|
|
100
|
+
def test_get_rank_token_counts(self, mock_vllm_config,
                               mock_kv_cache_config,
                               mock_structured_output_manager):
    """``_get_rank_token_counts`` asks each worker and reports its reply per rank."""
    worker_target = (
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with patch(worker_target), patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        # Swap in mock queues; each output queue answers with a fixed count.
        scripted_counts = [30, 15]
        scheduler.input_queues = [MagicMock(), MagicMock()]
        scheduler.output_queues = [MagicMock(), MagicMock()]
        for out_q, count in zip(scheduler.output_queues, scripted_counts):
            out_q.get = MagicMock(return_value=count)

        rank_tokens = scheduler._get_rank_token_counts()

        # Every rank must have received the GET_TOKEN_COUNT command.
        expected_cmd = (SchedulerCommand.GET_TOKEN_COUNT, None)
        for in_q in scheduler.input_queues:
            in_q.put.assert_called_with(expected_cmd)

        for rank, count in enumerate(scripted_counts):
            assert rank_tokens[rank] == count
|
|
133
|
+
|
|
134
|
+
def test_find_best_rank_with_cache_hit(self, mock_vllm_config,
                                       mock_kv_cache_config,
                                       mock_structured_output_manager):
    """``_find_best_rank_for_request`` picks the rank with the larger cache hit.

    Rank 0 carries more load (first reply 100 vs. 50) but, more
    importantly for this test, rank 1 reports the larger second-reply
    value (25 vs. 10), so rank 1 is expected to win.
    """
    worker_target = (
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with patch(worker_target), patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        mock_request = MagicMock(spec=Request)

        # Mock the queues
        scheduler.input_queues = [MagicMock(), MagicMock()]
        scheduler.output_queues = [MagicMock(), MagicMock()]

        # Each output queue is read twice, and ``side_effect`` hands the
        # replies out in order. The first reply looks like the rank's
        # token count; the second is a tuple whose last element is
        # presumably the number of cached tokens -- confirm against
        # DPScheduler._find_best_rank_for_request.
        scheduler.output_queues[0].get = MagicMock(
            side_effect=[100, ([], 10)])
        scheduler.output_queues[1].get = MagicMock(
            side_effect=[50, ([], 25)])

        rank = scheduler._find_best_rank_for_request(mock_request)

        # Should prefer rank with better cache hit
        assert rank == 1
|
|
177
|
+
|
|
178
|
+
def test_find_best_rank_without_cache_hit(self, mock_vllm_config,
                                          mock_kv_cache_config,
                                          mock_structured_output_manager):
    """With no cache hit on any rank, the least-loaded rank is selected."""
    worker_target = (
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with patch(worker_target), patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        request = MagicMock(spec=Request)

        # Replace the real queues with mocks.
        scheduler.input_queues = [MagicMock(), MagicMock()]
        scheduler.output_queues = [MagicMock(), MagicMock()]

        # Scripted replies per rank, consumed in order. Both ranks report
        # a zero cache hit; only the first reply (100 vs. 50) differs.
        replies_per_rank = ([100, ([], 0)], [50, ([], 0)])
        for out_q, replies in zip(scheduler.output_queues,
                                  replies_per_rank):
            out_q.get = MagicMock(side_effect=replies)

        chosen_rank = scheduler._find_best_rank_for_request(request)

        # Rank 1 has fewer tokens, so it should win.
        assert chosen_rank == 1
|
|
209
|
+
|
|
210
|
+
def test_add_request_assigns_to_best_rank(self, mock_vllm_config,
                                          mock_kv_cache_config,
                                          mock_structured_output_manager):
    """``add_request`` records the chosen rank and forwards the request to it."""
    worker_target = (
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with patch(worker_target), patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        request = MagicMock(spec=Request)
        request.request_id = "req1"

        # Replace the real queues with mocks.
        scheduler.input_queues = [MagicMock(), MagicMock()]
        scheduler.output_queues = [MagicMock(), MagicMock()]
        for out_q in scheduler.output_queues:
            out_q.get = MagicMock()

        # Pin rank selection so only the routing logic is under test.
        scheduler._find_best_rank_for_request = MagicMock(return_value=1)

        scheduler.add_request(request)

        # The request is bookkept under rank 1 ...
        assert scheduler.assigned_dp_rank["req1"] == 1

        # ... the ADD_REQUEST command went to rank 1's input queue ...
        expected_cmd = (SchedulerCommand.ADD_REQUEST, request)
        scheduler.input_queues[1].put.assert_called_with(expected_cmd)

        # ... and the scheduler waited for that rank's reply.
        scheduler.output_queues[1].get.assert_called_once()
|
|
249
|
+
|
|
250
|
+
def test_schedule_sends_commands_and_combines_output(
        self, mock_vllm_config, mock_kv_cache_config,
        mock_structured_output_manager):
    """``schedule`` sends SCHEDULE to every rank and merges the per-rank outputs."""

    def make_rank_output(req_id, num_tokens):
        """Build a SchedulerOutput-spec'd mock that schedules one request."""
        rank_output = MagicMock(spec=SchedulerOutput)
        rank_output.scheduled_new_reqs = []
        rank_output.num_scheduled_tokens = {req_id: num_tokens}
        rank_output.total_num_scheduled_tokens = num_tokens
        rank_output.finished_req_ids = set()
        rank_output.scheduled_cached_reqs = CachedRequestData(
            req_ids=[],
            resumed_req_ids=[],
            new_token_ids=[],
            all_token_ids=[],
            new_block_ids=[],
            num_computed_tokens=[],
            num_output_tokens=[],
        )
        rank_output.scheduled_spec_decode_tokens = {}
        rank_output.scheduled_encoder_inputs = {}
        rank_output.num_common_prefix_blocks = []
        return rank_output

    worker_target = (
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with patch(worker_target), patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        # Replace the real queues with mocks.
        scheduler.input_queues = [MagicMock(), MagicMock()]
        scheduler.output_queues = [MagicMock(), MagicMock()]

        # Rank 0 schedules 10 tokens for req1, rank 1 schedules 20 for req2.
        rank_outputs = [
            make_rank_output("req1", 10),
            make_rank_output("req2", 20),
        ]
        for out_q, rank_output in zip(scheduler.output_queues,
                                      rank_outputs):
            out_q.get = MagicMock(return_value=rank_output)

        # Requests are already assigned to their ranks.
        scheduler.assigned_dp_rank = {"req1": 0, "req2": 1}

        combined = scheduler.schedule()

        # Every rank must have received the SCHEDULE command.
        expected_cmd = (SchedulerCommand.SCHEDULE, None)
        for in_q in scheduler.input_queues:
            in_q.put.assert_called_with(expected_cmd)

        # The merged output sums the token totals and keeps both requests.
        assert isinstance(combined, DPSchedulerOutput)
        assert combined.total_num_scheduled_tokens == 30
        assert "req1" in combined.num_scheduled_tokens
        assert "req2" in combined.num_scheduled_tokens
        assert combined.assigned_dp_rank == {"req1": 0, "req2": 1}
|
|
329
|
+
|
|
330
|
+
def test_combine_cached_request_data(self, mock_vllm_config,
                                     mock_kv_cache_config,
                                     mock_structured_output_manager):
    """``_combine_cached_request_data`` concatenates cached-request fields in rank order."""
    worker_target = (
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with patch(worker_target), patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        # Rank 0: one cached request that is also marked as resumed.
        rank0_output = MagicMock(spec=SchedulerOutput)
        rank0_output.scheduled_cached_reqs = CachedRequestData(
            req_ids=["req1"],
            resumed_req_ids=["req1"],
            new_token_ids=[[1, 2, 3]],
            all_token_ids=[[1, 2, 3, 4, 5]],
            new_block_ids=[[10, 11]],
            num_computed_tokens=[5],
            num_output_tokens=[3],
        )

        # Rank 1: one cached request, nothing resumed.
        rank1_output = MagicMock(spec=SchedulerOutput)
        rank1_output.scheduled_cached_reqs = CachedRequestData(
            req_ids=["req2"],
            resumed_req_ids=[],
            new_token_ids=[[6, 7]],
            all_token_ids=[[6, 7, 8, 9]],
            new_block_ids=[[20, 21]],
            num_computed_tokens=[4],
            num_output_tokens=[2],
        )

        merged = scheduler._combine_cached_request_data(
            [rank0_output, rank1_output])

        # Rank 0 entries come first, followed by rank 1 entries.
        assert merged.req_ids == ["req1", "req2"]
        assert merged.resumed_req_ids == ["req1"]
        assert merged.new_token_ids == [[1, 2, 3], [6, 7]]
        assert merged.num_computed_tokens == [5, 4]
        assert merged.num_output_tokens == [3, 2]
|
|
377
|
+
|
|
378
|
+
def test_finish_requests_routes_to_workers(self, mock_vllm_config,
                                           mock_kv_cache_config,
                                           mock_structured_output_manager):
    """``finish_requests`` contacts every rank that owns one of the given requests."""
    worker_target = (
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with patch(worker_target), patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        # Replace the real queues with mocks.
        scheduler.input_queues = [MagicMock(), MagicMock()]
        scheduler.output_queues = [MagicMock(), MagicMock()]
        for out_q in scheduler.output_queues:
            out_q.get = MagicMock()

        # req1 and req3 live on rank 0, req2 on rank 1.
        scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}

        # Finish one request from each rank, passed as a list.
        scheduler.finish_requests(["req1", "req2"],
                                  finished_status="completed")

        # Both ranks own a finished request, so both must get a command.
        for in_q in scheduler.input_queues:
            in_q.put.assert_called()
|
|
407
|
+
|
|
408
|
+
def test_get_num_unfinished_requests(self, mock_vllm_config,
                                     mock_kv_cache_config,
                                     mock_structured_output_manager):
    """``get_num_unfinished_requests`` sums the counts reported by all ranks."""
    worker_target = (
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with patch(worker_target), patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        # Replace the real queues with mocks; ranks report 5 and 3.
        scheduler.input_queues = [MagicMock(), MagicMock()]
        scheduler.output_queues = [MagicMock(), MagicMock()]
        for out_q, unfinished in zip(scheduler.output_queues, (5, 3)):
            out_q.get = MagicMock(return_value=unfinished)

        total = scheduler.get_num_unfinished_requests()

        # Every rank must have been queried.
        expected_cmd = (SchedulerCommand.GET_NUM_UNFINISHED_REQUESTS, None)
        for in_q in scheduler.input_queues:
            in_q.put.assert_called_with(expected_cmd)

        assert total == 8
|
|
438
|
+
|
|
439
|
+
def test_has_finished_requests(self, mock_vllm_config,
                               mock_kv_cache_config,
                               mock_structured_output_manager):
    """Check that has_finished_requests consults every worker.

    One rank answers False and the other True; the scheduler is expected
    to report True and to have sent HAS_FINISHED_REQUESTS to both ranks.
    """
    worker_target = (
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with patch(worker_target), patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        # Swap the real inter-process queues for mocks, one pair per rank.
        scheduler.input_queues = [MagicMock() for _ in range(2)]
        scheduler.output_queues = [MagicMock() for _ in range(2)]

        # Rank 0 has nothing finished, rank 1 does.
        for out_queue, answer in zip(scheduler.output_queues, (False, True)):
            out_queue.get = MagicMock(return_value=answer)

        result = scheduler.has_finished_requests()

        assert result is True

        # Both ranks must have been queried.
        expected_command = (SchedulerCommand.HAS_FINISHED_REQUESTS, None)
        for in_queue in scheduler.input_queues:
            in_queue.put.assert_called_with(expected_command)
|
|
469
|
+
|
|
470
|
+
def test_get_request_counts(self, mock_vllm_config, mock_kv_cache_config,
                            mock_structured_output_manager):
    """Check that request counts are collected from every worker.

    The ranks reply with (running, waiting) pairs of (2, 1) and (1, 3);
    the scheduler is expected to return the element-wise totals (3, 4).
    """
    worker_target = (
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with patch(worker_target), patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        # Swap the real inter-process queues for mocks, one pair per rank.
        scheduler.input_queues = [MagicMock() for _ in range(2)]
        scheduler.output_queues = [MagicMock() for _ in range(2)]

        # Per-rank (running, waiting) replies.
        per_rank_counts = ((2, 1), (1, 3))
        for out_queue, counts in zip(scheduler.output_queues,
                                     per_rank_counts):
            out_queue.get = MagicMock(return_value=counts)

        running, waiting = scheduler.get_request_counts()

        # Both ranks must have been queried.
        expected_command = (SchedulerCommand.GET_REQUEST_COUNTS, None)
        for in_queue in scheduler.input_queues:
            in_queue.put.assert_called_with(expected_command)

        assert running == 3
        assert waiting == 4
|
|
500
|
+
|
|
501
|
+
def test_reset_prefix_cache(self, mock_vllm_config, mock_kv_cache_config,
                            mock_structured_output_manager):
    """Check that reset_prefix_cache is broadcast to every worker.

    Both ranks acknowledge with True; the scheduler is expected to send
    RESET_PREFIX_CACHE to each rank and to return True.
    """
    worker_target = (
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with patch(worker_target), patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        # Swap the real inter-process queues for mocks, one pair per rank.
        scheduler.input_queues = [MagicMock() for _ in range(2)]
        scheduler.output_queues = [MagicMock() for _ in range(2)]

        # Every rank reports a successful reset.
        for out_queue in scheduler.output_queues:
            out_queue.get = MagicMock(return_value=True)

        result = scheduler.reset_prefix_cache()

        # The reset command must reach both ranks.
        expected_command = (SchedulerCommand.RESET_PREFIX_CACHE, None)
        for in_queue in scheduler.input_queues:
            in_queue.put.assert_called_with(expected_command)

        assert result is True
|
|
530
|
+
|
|
531
|
+
def test_make_stats_aggregates_from_workers(
        self, mock_vllm_config, mock_kv_cache_config,
        mock_structured_output_manager):
    """Test make_stats aggregates statistics from all workers.

    Each of the two mocked worker output queues returns its own
    ``SchedulerStats``. The combined result must report 7 running and
    3 waiting requests and a KV-cache usage of 0.6 (the per-worker usages
    are 0.5 and 0.7).
    """
    with patch(
            'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
    ):
        with patch('multiprocessing.get_context'):
            scheduler = DPScheduler(
                vllm_config=mock_vllm_config,
                kv_cache_config=mock_kv_cache_config,
                structured_output_manager=mock_structured_output_manager,
                block_size=16,
                log_stats=True,
            )

            # Replace the real inter-process queues with mocks, one pair
            # per rank.
            scheduler.input_queues = [MagicMock(), MagicMock()]
            scheduler.output_queues = [MagicMock(), MagicMock()]

            # Stats reported by rank 0.
            stats_0 = SchedulerStats(
                num_running_reqs=3,
                num_waiting_reqs=2,
                kv_cache_usage=0.5,
                prefix_cache_stats=PrefixCacheStats(reset=False,
                                                    requests=10,
                                                    queries=8,
                                                    hits=5),
                connector_prefix_cache_stats=PrefixCacheStats(reset=False,
                                                              requests=5,
                                                              queries=4,
                                                              hits=2),
                spec_decoding_stats=None,
                kv_connector_stats=None,
            )

            # Stats reported by rank 1.
            stats_1 = SchedulerStats(
                num_running_reqs=4,
                num_waiting_reqs=1,
                kv_cache_usage=0.7,
                prefix_cache_stats=PrefixCacheStats(reset=False,
                                                    requests=15,
                                                    queries=12,
                                                    hits=8),
                connector_prefix_cache_stats=PrefixCacheStats(reset=False,
                                                              requests=6,
                                                              queries=5,
                                                              hits=3),
                spec_decoding_stats=None,
                kv_connector_stats=None,
            )

            scheduler.output_queues[0].get = MagicMock(
                return_value=stats_0)
            scheduler.output_queues[1].get = MagicMock(
                return_value=stats_1)

            combined_stats = scheduler.make_stats()

            # Verify commands were sent to both ranks.
            scheduler.input_queues[0].put.assert_called_with(
                (SchedulerCommand.MAKE_STATS, (None, None)))
            scheduler.input_queues[1].put.assert_called_with(
                (SchedulerCommand.MAKE_STATS, (None, None)))

            assert combined_stats.num_running_reqs == 7
            assert combined_stats.num_waiting_reqs == 3
            # Compare with a tolerance: the value is produced by float
            # arithmetic, so exact equality would depend on the order of
            # operations used by the aggregation.
            assert combined_stats.kv_cache_usage == pytest.approx(0.6)
|
|
599
|
+
|
|
600
|
+
def test_make_stats_returns_none_when_disabled(
        self, mock_vllm_config, mock_kv_cache_config,
        mock_structured_output_manager):
    """Check that make_stats yields None when built with log_stats=False."""
    worker_target = (
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with patch(worker_target), patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
            log_stats=False,
        )

        # With stats logging off there is nothing to report.
        stats = scheduler.make_stats()
        assert stats is None
|
|
618
|
+
|
|
619
|
+
def test_update_draft_token_ids(self, mock_vllm_config,
                                mock_kv_cache_config,
                                mock_structured_output_manager):
    """Check that draft token updates reach the workers owning the requests.

    Requests are assigned to ranks 0, 1 and 0; after the update both
    ranks' input queues must have received at least one command.
    """
    worker_target = (
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with patch(worker_target), patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        # Swap the real inter-process queues for mocks, one pair per rank.
        scheduler.input_queues = [MagicMock() for _ in range(2)]
        scheduler.output_queues = [MagicMock() for _ in range(2)]
        for out_queue in scheduler.output_queues:
            out_queue.get = MagicMock()

        # req1 and req3 live on rank 0, req2 on rank 1.
        scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}

        # One draft-token list per request, in req_ids order.
        draft_token_ids = MagicMock()
        draft_token_ids.req_ids = ["req1", "req2", "req3"]
        draft_token_ids.draft_token_ids = [
            [101, 102, 103],
            [201, 202],
            [301, 302, 303, 304],
        ]

        scheduler.update_draft_token_ids(draft_token_ids)

        # Each rank owns at least one request, so each must get a command.
        for in_queue in scheduler.input_queues:
            in_queue.put.assert_called()
|
|
654
|
+
|
|
655
|
+
def test_shutdown(self, mock_vllm_config, mock_kv_cache_config,
                  mock_structured_output_manager):
    """Check that shutdown notifies every worker and joins its process.

    Both ranks must receive a SHUTDOWN command and both (mocked) worker
    processes, which report themselves as not alive, must be joined.
    """
    worker_target = (
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with patch(worker_target), patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        # Swap the real inter-process queues for mocks, one pair per rank.
        scheduler.input_queues = [MagicMock() for _ in range(2)]
        scheduler.output_queues = [MagicMock() for _ in range(2)]
        for out_queue in scheduler.output_queues:
            out_queue.get = MagicMock()

        # Stand-in worker processes that have already exited.
        worker_procs = [MagicMock() for _ in range(2)]
        for proc in worker_procs:
            proc.is_alive = MagicMock(return_value=False)
        scheduler.processes = list(worker_procs)

        scheduler.shutdown()

        # The SHUTDOWN command must reach both ranks.
        expected_command = (SchedulerCommand.SHUTDOWN, None)
        for in_queue in scheduler.input_queues:
            in_queue.put.assert_called_with(expected_command)

        # Both worker processes must have been joined.
        for proc in worker_procs:
            proc.join.assert_called()
|
|
691
|
+
|
|
692
|
+
|
|
693
|
+
class TestUpdateVllmConfigForDPScheduler:
    """Tests for update_vllm_config_for_dp_scheduler."""

    def test_update_config_with_dp_size_greater_than_one(self):
        """With a total DP size of 2 the scheduler class is replaced."""
        mock_config = MagicMock()
        mock_config.sharding_config.total_dp_size = 2

        # Start from the stock scheduler, given as a dotted path string.
        sched_cfg = mock_config.scheduler_config
        sched_cfg._original_scheduler_cls = None
        sched_cfg.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"
        sched_cfg.async_scheduling = False

        update_vllm_config_for_dp_scheduler(mock_config)

        # The original class is remembered and DPScheduler takes its place.
        assert mock_config.scheduler_config._original_scheduler_cls == Scheduler
        assert mock_config.scheduler_config.scheduler_cls == DPScheduler

    def test_update_config_with_dp_size_one(self):
        """With a total DP size of 1 the scheduler class is left alone."""
        original_scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"

        mock_config = MagicMock()
        mock_config.sharding_config.total_dp_size = 1
        mock_config.scheduler_config.scheduler_cls = original_scheduler_cls

        update_vllm_config_for_dp_scheduler(mock_config)

        # scheduler_cls must still be the value set before the call.
        assert mock_config.scheduler_config.scheduler_cls == original_scheduler_cls
|
|
721
|
+
|
|
722
|
+
|
|
723
|
+
if __name__ == "__main__":
    # Allow running this test module directly. Propagate pytest's exit
    # status so the process exits non-zero when any test fails instead of
    # always reporting success.
    raise SystemExit(pytest.main([__file__, "-v"]))
|