tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tests/core/test_dp_scheduler.py
CHANGED
|
@@ -1,19 +1,30 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
from unittest.mock import MagicMock, patch
|
|
2
16
|
|
|
3
17
|
import pytest
|
|
4
|
-
import torch
|
|
5
18
|
from vllm.config import VllmConfig
|
|
6
|
-
from vllm.v1.core.sched.output import
|
|
7
|
-
SchedulerOutput)
|
|
19
|
+
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
|
|
8
20
|
from vllm.v1.core.sched.scheduler import Scheduler
|
|
9
|
-
from vllm.v1.engine import EngineCoreOutputs
|
|
10
21
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
|
11
22
|
from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
|
|
12
|
-
from vllm.v1.outputs import ModelRunnerOutput
|
|
13
23
|
from vllm.v1.request import Request
|
|
14
24
|
|
|
15
25
|
from tpu_inference.core.sched.dp_scheduler import (
|
|
16
|
-
DPScheduler, DPSchedulerOutput,
|
|
26
|
+
DPScheduler, DPSchedulerOutput, SchedulerCommand,
|
|
27
|
+
update_vllm_config_for_dp_scheduler)
|
|
17
28
|
|
|
18
29
|
|
|
19
30
|
class TestDPScheduler:
|
|
@@ -43,387 +54,241 @@ class TestDPScheduler:
|
|
|
43
54
|
"""Create a mock StructuredOutputManager."""
|
|
44
55
|
return MagicMock()
|
|
45
56
|
|
|
46
|
-
def
|
|
47
|
-
mock_kv_cache_config,
|
|
48
|
-
mock_structured_output_manager,
|
|
49
|
-
**kwargs):
|
|
50
|
-
"""Helper to create a DPScheduler with properly mocked schedulers."""
|
|
51
|
-
# Create individual mock scheduler instances
|
|
52
|
-
mock_scheduler_0 = MagicMock()
|
|
53
|
-
mock_scheduler_1 = MagicMock()
|
|
54
|
-
|
|
55
|
-
# Patch the Scheduler class to return our mock instances
|
|
56
|
-
with patch.object(
|
|
57
|
-
mock_vllm_config.scheduler_config, '_original_scheduler_cls',
|
|
58
|
-
MagicMock(side_effect=[mock_scheduler_0, mock_scheduler_1])):
|
|
59
|
-
scheduler = DPScheduler(
|
|
60
|
-
vllm_config=mock_vllm_config,
|
|
61
|
-
kv_cache_config=mock_kv_cache_config,
|
|
62
|
-
structured_output_manager=mock_structured_output_manager,
|
|
63
|
-
block_size=16,
|
|
64
|
-
**kwargs)
|
|
65
|
-
|
|
66
|
-
return scheduler
|
|
67
|
-
|
|
68
|
-
def test_init_creates_per_rank_schedulers(
|
|
57
|
+
def test_init_creates_worker_processes(
|
|
69
58
|
self,
|
|
70
59
|
mock_vllm_config,
|
|
71
60
|
mock_kv_cache_config,
|
|
72
61
|
mock_structured_output_manager,
|
|
73
62
|
):
|
|
74
|
-
"""Test
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
assert
|
|
63
|
+
"""Test initialization creates worker processes for each DP rank."""
|
|
64
|
+
with patch(
|
|
65
|
+
'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
|
|
66
|
+
):
|
|
67
|
+
with patch('multiprocessing.get_context') as mock_get_context:
|
|
68
|
+
# Setup mock context
|
|
69
|
+
mock_ctx = MagicMock()
|
|
70
|
+
mock_process = MagicMock()
|
|
71
|
+
mock_queue = MagicMock()
|
|
72
|
+
|
|
73
|
+
mock_ctx.Queue = MagicMock(return_value=mock_queue)
|
|
74
|
+
mock_ctx.Process = MagicMock(return_value=mock_process)
|
|
75
|
+
mock_get_context.return_value = mock_ctx
|
|
76
|
+
|
|
77
|
+
scheduler = DPScheduler(
|
|
78
|
+
vllm_config=mock_vllm_config,
|
|
79
|
+
kv_cache_config=mock_kv_cache_config,
|
|
80
|
+
structured_output_manager=mock_structured_output_manager,
|
|
81
|
+
block_size=16,
|
|
82
|
+
log_stats=True,
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
# Verify processes and queues were created
|
|
86
|
+
assert scheduler.dp_size == 2
|
|
87
|
+
assert len(scheduler.processes) == 2
|
|
88
|
+
assert len(scheduler.input_queues) == 2
|
|
89
|
+
# output_queues is a dict with (rank, command) tuple keys
|
|
90
|
+
# 2 ranks × 14 commands (SchedulerCommand enum)
|
|
91
|
+
assert len(scheduler.output_queues) == 28
|
|
92
|
+
assert scheduler.log_stats is True
|
|
93
|
+
assert len(scheduler.per_rank_kv_cache_configs) == 2
|
|
94
|
+
|
|
95
|
+
# Verify each rank got the correct config
|
|
96
|
+
for rank_config in scheduler.per_rank_kv_cache_configs:
|
|
97
|
+
assert rank_config.num_blocks == 50 # 100 / 2
|
|
98
|
+
|
|
99
|
+
# Verify processes were started
|
|
100
|
+
assert mock_process.start.call_count == 2
|
|
98
101
|
|
|
99
102
|
def test_get_rank_token_counts(self, mock_vllm_config,
|
|
100
103
|
mock_kv_cache_config,
|
|
101
104
|
mock_structured_output_manager):
|
|
102
|
-
"""Test _get_rank_token_counts
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
105
|
+
"""Test _get_rank_token_counts queries workers and aggregates tokens."""
|
|
106
|
+
with patch(
|
|
107
|
+
'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
|
|
108
|
+
):
|
|
109
|
+
with patch('multiprocessing.get_context'):
|
|
110
|
+
scheduler = DPScheduler(
|
|
111
|
+
vllm_config=mock_vllm_config,
|
|
112
|
+
kv_cache_config=mock_kv_cache_config,
|
|
113
|
+
structured_output_manager=mock_structured_output_manager,
|
|
114
|
+
block_size=16,
|
|
115
|
+
)
|
|
116
|
+
|
|
117
|
+
# Mock the queues - need to mock the .get() method to return the value
|
|
118
|
+
scheduler.input_queues = [MagicMock(), MagicMock()]
|
|
119
|
+
|
|
120
|
+
mock_queue_0 = MagicMock()
|
|
121
|
+
mock_queue_0.get.return_value = 30
|
|
122
|
+
mock_queue_1 = MagicMock()
|
|
123
|
+
mock_queue_1.get.return_value = 15
|
|
124
|
+
|
|
125
|
+
scheduler.output_queues = {
|
|
126
|
+
(0, "get_token_count"): mock_queue_0,
|
|
127
|
+
(1, "get_token_count"): mock_queue_1,
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
rank_tokens = scheduler._get_rank_token_counts()
|
|
131
|
+
|
|
132
|
+
# Verify correct commands were sent
|
|
133
|
+
scheduler.input_queues[0].put.assert_called_with(
|
|
134
|
+
(SchedulerCommand.GET_TOKEN_COUNT, None))
|
|
135
|
+
scheduler.input_queues[1].put.assert_called_with(
|
|
136
|
+
(SchedulerCommand.GET_TOKEN_COUNT, None))
|
|
137
|
+
|
|
138
|
+
assert rank_tokens[0] == 30
|
|
139
|
+
assert rank_tokens[1] == 15
|
|
124
140
|
|
|
125
141
|
def test_find_best_rank_with_cache_hit(self, mock_vllm_config,
|
|
126
142
|
mock_kv_cache_config,
|
|
127
143
|
mock_structured_output_manager):
|
|
128
|
-
"""Test _find_best_rank_for_request
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
144
|
+
"""Test _find_best_rank_for_request prefers cache hits."""
|
|
145
|
+
with patch(
|
|
146
|
+
'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
|
|
147
|
+
):
|
|
148
|
+
with patch('multiprocessing.get_context'):
|
|
149
|
+
scheduler = DPScheduler(
|
|
150
|
+
vllm_config=mock_vllm_config,
|
|
151
|
+
kv_cache_config=mock_kv_cache_config,
|
|
152
|
+
structured_output_manager=mock_structured_output_manager,
|
|
153
|
+
block_size=16,
|
|
154
|
+
)
|
|
155
|
+
|
|
156
|
+
mock_request = MagicMock(spec=Request)
|
|
157
|
+
|
|
158
|
+
# Mock the queues with tuple keys (rank, command)
|
|
159
|
+
scheduler.input_queues = [MagicMock(), MagicMock()]
|
|
160
|
+
|
|
161
|
+
# Create proper mocks for queue.get() calls
|
|
162
|
+
mock_queue_get_token_0 = MagicMock()
|
|
163
|
+
mock_queue_get_token_0.get.return_value = 100
|
|
164
|
+
mock_queue_get_token_1 = MagicMock()
|
|
165
|
+
mock_queue_get_token_1.get.return_value = 50
|
|
166
|
+
mock_queue_computed_0 = MagicMock()
|
|
167
|
+
mock_queue_computed_0.get.return_value = ([], 10)
|
|
168
|
+
mock_queue_computed_1 = MagicMock()
|
|
169
|
+
mock_queue_computed_1.get.return_value = ([], 25)
|
|
170
|
+
|
|
171
|
+
scheduler.output_queues = {
|
|
172
|
+
(0, "get_token_count"): mock_queue_get_token_0,
|
|
173
|
+
(1, "get_token_count"): mock_queue_get_token_1,
|
|
174
|
+
(0, "get_computed_blocks"): mock_queue_computed_0,
|
|
175
|
+
(1, "get_computed_blocks"): mock_queue_computed_1,
|
|
176
|
+
}
|
|
177
|
+
|
|
178
|
+
rank = scheduler._find_best_rank_for_request(mock_request)
|
|
179
|
+
|
|
180
|
+
# Should prefer rank with better cache hit
|
|
181
|
+
assert rank == 1
|
|
161
182
|
|
|
162
183
|
def test_find_best_rank_without_cache_hit(self, mock_vllm_config,
|
|
163
184
|
mock_kv_cache_config,
|
|
164
185
|
mock_structured_output_manager):
|
|
165
|
-
"""Test _find_best_rank_for_request without cache hit
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
186
|
+
"""Test _find_best_rank_for_request uses load balancing without cache hit."""
|
|
187
|
+
with patch(
|
|
188
|
+
'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
|
|
189
|
+
):
|
|
190
|
+
with patch('multiprocessing.get_context'):
|
|
191
|
+
scheduler = DPScheduler(
|
|
192
|
+
vllm_config=mock_vllm_config,
|
|
193
|
+
kv_cache_config=mock_kv_cache_config,
|
|
194
|
+
structured_output_manager=mock_structured_output_manager,
|
|
195
|
+
block_size=16,
|
|
196
|
+
)
|
|
197
|
+
|
|
198
|
+
mock_request = MagicMock(spec=Request)
|
|
199
|
+
|
|
200
|
+
# Mock the queues with tuple keys (rank, command)
|
|
201
|
+
scheduler.input_queues = [MagicMock(), MagicMock()]
|
|
202
|
+
|
|
203
|
+
# Create proper mocks for queue.get() calls
|
|
204
|
+
mock_queue_get_token_0 = MagicMock()
|
|
205
|
+
mock_queue_get_token_0.get.return_value = 100
|
|
206
|
+
mock_queue_get_token_1 = MagicMock()
|
|
207
|
+
mock_queue_get_token_1.get.return_value = 50
|
|
208
|
+
mock_queue_computed_0 = MagicMock()
|
|
209
|
+
mock_queue_computed_0.get.return_value = ([], 0)
|
|
210
|
+
mock_queue_computed_1 = MagicMock()
|
|
211
|
+
mock_queue_computed_1.get.return_value = ([], 0)
|
|
212
|
+
|
|
213
|
+
scheduler.output_queues = {
|
|
214
|
+
(0, "get_token_count"): mock_queue_get_token_0,
|
|
215
|
+
(1, "get_token_count"): mock_queue_get_token_1,
|
|
216
|
+
(0, "get_computed_blocks"): mock_queue_computed_0,
|
|
217
|
+
(1, "get_computed_blocks"): mock_queue_computed_1,
|
|
218
|
+
}
|
|
219
|
+
|
|
220
|
+
rank = scheduler._find_best_rank_for_request(mock_request)
|
|
221
|
+
|
|
222
|
+
# Should choose rank with fewer tokens (rank 1)
|
|
223
|
+
assert rank == 1
|
|
197
224
|
|
|
198
225
|
def test_add_request_assigns_to_best_rank(self, mock_vllm_config,
|
|
199
226
|
mock_kv_cache_config,
|
|
200
227
|
mock_structured_output_manager):
|
|
201
|
-
"""Test add_request assigns
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
resumed_req_ids=[],
|
|
242
|
-
new_token_ids=[],
|
|
243
|
-
all_token_ids=[],
|
|
244
|
-
new_block_ids=[],
|
|
245
|
-
num_computed_tokens=[],
|
|
246
|
-
num_output_tokens=[],
|
|
247
|
-
)
|
|
248
|
-
mock_output_0.scheduled_spec_decode_tokens = {}
|
|
249
|
-
mock_output_0.scheduled_encoder_inputs = {}
|
|
250
|
-
mock_output_0.num_common_prefix_blocks = []
|
|
251
|
-
|
|
252
|
-
mock_output_1 = MagicMock(spec=SchedulerOutput)
|
|
253
|
-
mock_output_1.scheduled_new_reqs = []
|
|
254
|
-
mock_output_1.num_scheduled_tokens = {"req2": 20}
|
|
255
|
-
mock_output_1.total_num_scheduled_tokens = 20
|
|
256
|
-
mock_output_1.finished_req_ids = set()
|
|
257
|
-
mock_output_1.scheduled_cached_reqs = CachedRequestData(
|
|
258
|
-
req_ids=[],
|
|
259
|
-
resumed_req_ids=[],
|
|
260
|
-
new_token_ids=[],
|
|
261
|
-
all_token_ids=[],
|
|
262
|
-
new_block_ids=[],
|
|
263
|
-
num_computed_tokens=[],
|
|
264
|
-
num_output_tokens=[],
|
|
265
|
-
)
|
|
266
|
-
mock_output_1.scheduled_spec_decode_tokens = {}
|
|
267
|
-
mock_output_1.scheduled_encoder_inputs = {}
|
|
268
|
-
mock_output_1.num_common_prefix_blocks = []
|
|
269
|
-
|
|
270
|
-
scheduler.schedulers[0].schedule = MagicMock(
|
|
271
|
-
return_value=mock_output_0)
|
|
272
|
-
scheduler.schedulers[1].schedule = MagicMock(
|
|
273
|
-
return_value=mock_output_1)
|
|
274
|
-
scheduler.schedulers[0].running = []
|
|
275
|
-
scheduler.schedulers[0].waiting = []
|
|
276
|
-
scheduler.schedulers[1].running = []
|
|
277
|
-
scheduler.schedulers[1].waiting = []
|
|
278
|
-
|
|
279
|
-
# Assign ranks for requests
|
|
280
|
-
scheduler.assigned_dp_rank = {"req1": 0, "req2": 1}
|
|
281
|
-
|
|
282
|
-
output = scheduler.schedule()
|
|
283
|
-
|
|
284
|
-
# Verify combined output
|
|
285
|
-
assert isinstance(output, DPSchedulerOutput)
|
|
286
|
-
assert output.total_num_scheduled_tokens == 30 # 10 + 20
|
|
287
|
-
assert "req1" in output.num_scheduled_tokens
|
|
288
|
-
assert "req2" in output.num_scheduled_tokens
|
|
289
|
-
assert output.assigned_dp_rank == {"req1": 0, "req2": 1}
|
|
290
|
-
|
|
291
|
-
def test_combine_cached_request_data(self, mock_vllm_config,
|
|
292
|
-
mock_kv_cache_config,
|
|
293
|
-
mock_structured_output_manager):
|
|
294
|
-
"""Test _combine_cached_request_data combines data from all ranks."""
|
|
295
|
-
mock_scheduler_cls = MagicMock(return_value=MagicMock())
|
|
296
|
-
with patch.object(mock_vllm_config.scheduler_config,
|
|
297
|
-
'_original_scheduler_cls', mock_scheduler_cls):
|
|
298
|
-
scheduler = DPScheduler(
|
|
299
|
-
vllm_config=mock_vllm_config,
|
|
300
|
-
kv_cache_config=mock_kv_cache_config,
|
|
301
|
-
structured_output_manager=mock_structured_output_manager,
|
|
302
|
-
block_size=16,
|
|
303
|
-
)
|
|
304
|
-
|
|
305
|
-
# Create mock rank outputs with different cached request data
|
|
306
|
-
output_0 = MagicMock(spec=SchedulerOutput)
|
|
307
|
-
output_0.scheduled_cached_reqs = CachedRequestData(
|
|
308
|
-
req_ids=["req1"],
|
|
309
|
-
resumed_req_ids=["req1"],
|
|
310
|
-
new_token_ids=[[1, 2, 3]],
|
|
311
|
-
all_token_ids=[[1, 2, 3, 4, 5]],
|
|
312
|
-
new_block_ids=[[10, 11]],
|
|
313
|
-
num_computed_tokens=[5],
|
|
314
|
-
num_output_tokens=[3],
|
|
315
|
-
)
|
|
316
|
-
|
|
317
|
-
output_1 = MagicMock(spec=SchedulerOutput)
|
|
318
|
-
output_1.scheduled_cached_reqs = CachedRequestData(
|
|
319
|
-
req_ids=["req2"],
|
|
320
|
-
resumed_req_ids=[],
|
|
321
|
-
new_token_ids=[[6, 7]],
|
|
322
|
-
all_token_ids=[[6, 7, 8, 9]],
|
|
323
|
-
new_block_ids=[[20, 21]],
|
|
324
|
-
num_computed_tokens=[4],
|
|
325
|
-
num_output_tokens=[2],
|
|
326
|
-
)
|
|
327
|
-
|
|
328
|
-
rank_outputs = [output_0, output_1]
|
|
329
|
-
combined = scheduler._combine_cached_request_data(rank_outputs)
|
|
330
|
-
|
|
331
|
-
# Verify combined data
|
|
332
|
-
assert combined.req_ids == ["req1", "req2"]
|
|
333
|
-
assert combined.resumed_req_ids == ["req1"]
|
|
334
|
-
assert combined.new_token_ids == [[1, 2, 3], [6, 7]]
|
|
335
|
-
assert combined.all_token_ids == [[1, 2, 3, 4, 5], [6, 7, 8, 9]]
|
|
336
|
-
assert combined.new_block_ids == [[10, 11], [20, 21]]
|
|
337
|
-
assert combined.num_computed_tokens == [5, 4]
|
|
338
|
-
assert combined.num_output_tokens == [3, 2]
|
|
339
|
-
|
|
340
|
-
def test_get_grammar_bitmask_with_structured_output(
|
|
228
|
+
"""Test add_request assigns request to best rank."""
|
|
229
|
+
with patch(
|
|
230
|
+
'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
|
|
231
|
+
):
|
|
232
|
+
with patch('multiprocessing.get_context'):
|
|
233
|
+
scheduler = DPScheduler(
|
|
234
|
+
vllm_config=mock_vllm_config,
|
|
235
|
+
kv_cache_config=mock_kv_cache_config,
|
|
236
|
+
structured_output_manager=mock_structured_output_manager,
|
|
237
|
+
block_size=16,
|
|
238
|
+
)
|
|
239
|
+
|
|
240
|
+
mock_request = MagicMock(spec=Request)
|
|
241
|
+
mock_request.request_id = "req1"
|
|
242
|
+
|
|
243
|
+
# Mock the queues with tuple keys
|
|
244
|
+
scheduler.input_queues = [MagicMock(), MagicMock()]
|
|
245
|
+
scheduler.output_queues = {
|
|
246
|
+
(0, "add_request"): MagicMock(),
|
|
247
|
+
(1, "add_request"): MagicMock(),
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
# Mock _find_best_rank_for_request to return rank 1
|
|
251
|
+
scheduler._find_best_rank_for_request = MagicMock(
|
|
252
|
+
return_value=1)
|
|
253
|
+
|
|
254
|
+
scheduler.add_request(mock_request)
|
|
255
|
+
|
|
256
|
+
# Verify request was assigned to rank 1
|
|
257
|
+
assert scheduler.assigned_dp_rank["req1"] == 1
|
|
258
|
+
|
|
259
|
+
# Verify ADD_REQUEST command was sent to rank 1
|
|
260
|
+
scheduler.input_queues[1].put.assert_called_with(
|
|
261
|
+
(SchedulerCommand.ADD_REQUEST, mock_request))
|
|
262
|
+
|
|
263
|
+
# Verify we waited for completion
|
|
264
|
+
scheduler.output_queues[(
|
|
265
|
+
1, "add_request")].get.assert_called_once()
|
|
266
|
+
|
|
267
|
+
def test_schedule_sends_commands_and_combines_output(
|
|
341
268
|
self, mock_vllm_config, mock_kv_cache_config,
|
|
342
269
|
mock_structured_output_manager):
|
|
343
|
-
"""Test
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
return_value=grammar_output_1)
|
|
366
|
-
|
|
367
|
-
# Cache scheduler outputs
|
|
368
|
-
scheduler.cached_schedulers_output.append(
|
|
369
|
-
[mock_output_0, mock_output_1])
|
|
370
|
-
|
|
371
|
-
# Create a DPSchedulerOutput
|
|
372
|
-
dp_output = DPSchedulerOutput(
|
|
373
|
-
scheduled_new_reqs=[],
|
|
374
|
-
scheduled_cached_reqs=CachedRequestData(
|
|
375
|
-
req_ids=[],
|
|
376
|
-
resumed_req_ids=[],
|
|
377
|
-
new_token_ids=[],
|
|
378
|
-
all_token_ids=[],
|
|
379
|
-
new_block_ids=[],
|
|
380
|
-
num_computed_tokens=[],
|
|
381
|
-
num_output_tokens=[],
|
|
382
|
-
),
|
|
383
|
-
num_scheduled_tokens={},
|
|
384
|
-
total_num_scheduled_tokens=0,
|
|
385
|
-
scheduled_spec_decode_tokens={},
|
|
386
|
-
scheduled_encoder_inputs={},
|
|
387
|
-
num_common_prefix_blocks=[],
|
|
388
|
-
finished_req_ids=set(),
|
|
389
|
-
free_encoder_mm_hashes=set(),
|
|
390
|
-
)
|
|
391
|
-
|
|
392
|
-
result = scheduler.get_grammar_bitmask(dp_output)
|
|
393
|
-
|
|
394
|
-
assert result is not None
|
|
395
|
-
assert result.structured_output_request_ids == ["req1", "req2"]
|
|
396
|
-
assert result.grammar_bitmask.shape == (2, 100)
|
|
397
|
-
|
|
398
|
-
def test_get_grammar_bitmask_no_structured_output(
|
|
399
|
-
self, mock_vllm_config, mock_kv_cache_config,
|
|
400
|
-
mock_structured_output_manager):
|
|
401
|
-
"""Test get_grammar_bitmask returns None when no structured output."""
|
|
402
|
-
mock_scheduler_cls = MagicMock(return_value=MagicMock())
|
|
403
|
-
with patch.object(mock_vllm_config.scheduler_config,
|
|
404
|
-
'_original_scheduler_cls', mock_scheduler_cls):
|
|
405
|
-
scheduler = DPScheduler(
|
|
406
|
-
vllm_config=mock_vllm_config,
|
|
407
|
-
kv_cache_config=mock_kv_cache_config,
|
|
408
|
-
structured_output_manager=mock_structured_output_manager,
|
|
409
|
-
block_size=16,
|
|
410
|
-
)
|
|
411
|
-
|
|
412
|
-
# Mock schedulers returning None
|
|
413
|
-
scheduler.schedulers[0].get_grammar_bitmask = MagicMock(
|
|
414
|
-
return_value=None)
|
|
415
|
-
scheduler.schedulers[1].get_grammar_bitmask = MagicMock(
|
|
416
|
-
return_value=None)
|
|
417
|
-
|
|
418
|
-
# Cache scheduler outputs
|
|
419
|
-
mock_output_0 = MagicMock()
|
|
420
|
-
mock_output_1 = MagicMock()
|
|
421
|
-
scheduler.cached_schedulers_output.append(
|
|
422
|
-
[mock_output_0, mock_output_1])
|
|
423
|
-
|
|
424
|
-
dp_output = DPSchedulerOutput(
|
|
425
|
-
scheduled_new_reqs=[],
|
|
426
|
-
scheduled_cached_reqs=CachedRequestData(
|
|
270
|
+
"""Test schedule sends SCHEDULE command to all workers and combines output."""
|
|
271
|
+
with patch(
|
|
272
|
+
'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
|
|
273
|
+
):
|
|
274
|
+
with patch('multiprocessing.get_context'):
|
|
275
|
+
scheduler = DPScheduler(
|
|
276
|
+
vllm_config=mock_vllm_config,
|
|
277
|
+
kv_cache_config=mock_kv_cache_config,
|
|
278
|
+
structured_output_manager=mock_structured_output_manager,
|
|
279
|
+
block_size=16,
|
|
280
|
+
)
|
|
281
|
+
|
|
282
|
+
# Mock the queues with tuple keys
|
|
283
|
+
scheduler.input_queues = [MagicMock(), MagicMock()]
|
|
284
|
+
|
|
285
|
+
# Create mock scheduler outputs
|
|
286
|
+
mock_output_0 = MagicMock(spec=SchedulerOutput)
|
|
287
|
+
mock_output_0.scheduled_new_reqs = []
|
|
288
|
+
mock_output_0.num_scheduled_tokens = {"req1": 10}
|
|
289
|
+
mock_output_0.total_num_scheduled_tokens = 10
|
|
290
|
+
mock_output_0.finished_req_ids = set()
|
|
291
|
+
mock_output_0.scheduled_cached_reqs = CachedRequestData(
|
|
427
292
|
req_ids=[],
|
|
428
293
|
resumed_req_ids=[],
|
|
429
294
|
new_token_ids=[],
|
|
@@ -431,40 +296,17 @@ class TestDPScheduler:
|
|
|
431
296
|
new_block_ids=[],
|
|
432
297
|
num_computed_tokens=[],
|
|
433
298
|
num_output_tokens=[],
|
|
434
|
-
)
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
assert result is None
|
|
446
|
-
|
|
447
|
-
def test_update_from_output_routes_to_schedulers(
|
|
448
|
-
self, mock_vllm_config, mock_kv_cache_config,
|
|
449
|
-
mock_structured_output_manager):
|
|
450
|
-
"""Test update_from_output splits output and updates each scheduler."""
|
|
451
|
-
mock_scheduler_cls = MagicMock(return_value=MagicMock())
|
|
452
|
-
with patch.object(mock_vllm_config.scheduler_config,
|
|
453
|
-
'_original_scheduler_cls', mock_scheduler_cls):
|
|
454
|
-
scheduler = DPScheduler(
|
|
455
|
-
vllm_config=mock_vllm_config,
|
|
456
|
-
kv_cache_config=mock_kv_cache_config,
|
|
457
|
-
structured_output_manager=mock_structured_output_manager,
|
|
458
|
-
block_size=16,
|
|
459
|
-
)
|
|
460
|
-
|
|
461
|
-
# Setup assigned ranks
|
|
462
|
-
scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}
|
|
463
|
-
|
|
464
|
-
# Create DPSchedulerOutput
|
|
465
|
-
dp_output = DPSchedulerOutput(
|
|
466
|
-
scheduled_new_reqs=[],
|
|
467
|
-
scheduled_cached_reqs=CachedRequestData(
|
|
299
|
+
)
|
|
300
|
+
mock_output_0.scheduled_spec_decode_tokens = {}
|
|
301
|
+
mock_output_0.scheduled_encoder_inputs = {}
|
|
302
|
+
mock_output_0.num_common_prefix_blocks = []
|
|
303
|
+
|
|
304
|
+
mock_output_1 = MagicMock(spec=SchedulerOutput)
|
|
305
|
+
mock_output_1.scheduled_new_reqs = []
|
|
306
|
+
mock_output_1.num_scheduled_tokens = {"req2": 20}
|
|
307
|
+
mock_output_1.total_num_scheduled_tokens = 20
|
|
308
|
+
mock_output_1.finished_req_ids = set()
|
|
309
|
+
mock_output_1.scheduled_cached_reqs = CachedRequestData(
|
|
468
310
|
req_ids=[],
|
|
469
311
|
resumed_req_ids=[],
|
|
470
312
|
new_token_ids=[],
|
|
@@ -472,397 +314,437 @@ class TestDPScheduler:
|
|
|
472
314
|
new_block_ids=[],
|
|
473
315
|
num_computed_tokens=[],
|
|
474
316
|
num_output_tokens=[],
|
|
475
|
-
)
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
"
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
)
|
|
509
|
-
|
|
510
|
-
# Mock rank scheduler outputs (cached from schedule call)
|
|
511
|
-
rank_output_0 = MagicMock()
|
|
512
|
-
rank_output_1 = MagicMock()
|
|
513
|
-
scheduler.cached_schedulers_output.append(
|
|
514
|
-
[rank_output_0, rank_output_1])
|
|
515
|
-
|
|
516
|
-
# Mock scheduler update_from_output
|
|
517
|
-
engine_output_0 = EngineCoreOutputs()
|
|
518
|
-
engine_output_0.engine_index = 0
|
|
519
|
-
engine_output_0.outputs = []
|
|
520
|
-
engine_output_0.finished_requests = {"req3"}
|
|
521
|
-
|
|
522
|
-
engine_output_1 = EngineCoreOutputs()
|
|
523
|
-
engine_output_1.engine_index = 0
|
|
524
|
-
engine_output_1.outputs = []
|
|
525
|
-
engine_output_1.finished_requests = set()
|
|
526
|
-
|
|
527
|
-
scheduler.schedulers[0].update_from_output = MagicMock(
|
|
528
|
-
return_value={0: engine_output_0})
|
|
529
|
-
scheduler.schedulers[1].update_from_output = MagicMock(
|
|
530
|
-
return_value={0: engine_output_1})
|
|
531
|
-
|
|
532
|
-
# Mock make_stats
|
|
533
|
-
scheduler.make_stats = MagicMock(return_value=None)
|
|
534
|
-
|
|
535
|
-
_ = scheduler.update_from_output(dp_output, model_output)
|
|
536
|
-
|
|
537
|
-
# Verify schedulers were updated
|
|
538
|
-
assert scheduler.schedulers[0].update_from_output.called
|
|
539
|
-
assert scheduler.schedulers[1].update_from_output.called
|
|
540
|
-
|
|
541
|
-
# Verify finished request was cleaned up
|
|
542
|
-
assert "req3" not in scheduler.assigned_dp_rank
|
|
543
|
-
assert "req1" in scheduler.assigned_dp_rank
|
|
544
|
-
assert "req2" in scheduler.assigned_dp_rank
|
|
545
|
-
|
|
546
|
-
def test_split_model_output_by_rank(self, mock_vllm_config,
|
|
547
|
-
mock_kv_cache_config,
|
|
548
|
-
mock_structured_output_manager):
|
|
549
|
-
"""Test _split_model_output_by_rank distributes output correctly."""
|
|
550
|
-
mock_scheduler_cls = MagicMock(return_value=MagicMock())
|
|
551
|
-
with patch.object(mock_vllm_config.scheduler_config,
|
|
552
|
-
'_original_scheduler_cls', mock_scheduler_cls):
|
|
553
|
-
scheduler = DPScheduler(
|
|
554
|
-
vllm_config=mock_vllm_config,
|
|
555
|
-
kv_cache_config=mock_kv_cache_config,
|
|
556
|
-
structured_output_manager=mock_structured_output_manager,
|
|
557
|
-
block_size=16,
|
|
558
|
-
)
|
|
559
|
-
|
|
560
|
-
# Setup assigned ranks
|
|
561
|
-
scheduler.assigned_dp_rank = {
|
|
562
|
-
"req1": 0,
|
|
563
|
-
"req2": 1,
|
|
564
|
-
"req3": 0,
|
|
565
|
-
"req4": 1
|
|
566
|
-
}
|
|
567
|
-
|
|
568
|
-
# Create global model output
|
|
569
|
-
global_output = ModelRunnerOutput(
|
|
570
|
-
req_ids=["req1", "req2", "req3", "req4"],
|
|
571
|
-
req_id_to_index={
|
|
572
|
-
"req1": 0,
|
|
573
|
-
"req2": 1,
|
|
574
|
-
"req3": 2,
|
|
575
|
-
"req4": 3
|
|
576
|
-
},
|
|
577
|
-
sampled_token_ids=torch.tensor([100, 200, 300, 400]),
|
|
578
|
-
logprobs=None,
|
|
579
|
-
prompt_logprobs_dict={},
|
|
580
|
-
pooler_output=None,
|
|
581
|
-
num_nans_in_logits=0,
|
|
582
|
-
kv_connector_output=None,
|
|
583
|
-
)
|
|
584
|
-
|
|
585
|
-
rank_outputs = scheduler._split_model_output_by_rank(global_output)
|
|
586
|
-
|
|
587
|
-
# Verify split outputs
|
|
588
|
-
assert len(rank_outputs) == 2
|
|
589
|
-
assert rank_outputs[0].req_ids == ["req1", "req3"]
|
|
590
|
-
assert rank_outputs[1].req_ids == ["req2", "req4"]
|
|
591
|
-
|
|
592
|
-
def test_cleanup_finished_requests(self, mock_vllm_config,
|
|
593
|
-
mock_kv_cache_config,
|
|
594
|
-
mock_structured_output_manager):
|
|
595
|
-
"""Test _cleanup_finished_requests removes finished requests."""
|
|
596
|
-
mock_scheduler_cls = MagicMock(return_value=MagicMock())
|
|
597
|
-
with patch.object(mock_vllm_config.scheduler_config,
|
|
598
|
-
'_original_scheduler_cls', mock_scheduler_cls):
|
|
599
|
-
scheduler = DPScheduler(
|
|
600
|
-
vllm_config=mock_vllm_config,
|
|
601
|
-
kv_cache_config=mock_kv_cache_config,
|
|
602
|
-
structured_output_manager=mock_structured_output_manager,
|
|
603
|
-
block_size=16,
|
|
604
|
-
)
|
|
605
|
-
|
|
606
|
-
# Setup assigned ranks
|
|
607
|
-
scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}
|
|
608
|
-
|
|
609
|
-
# Clean up finished requests
|
|
610
|
-
scheduler._cleanup_finished_requests({"req1", "req3"})
|
|
611
|
-
|
|
612
|
-
# Verify cleanup
|
|
613
|
-
assert "req1" not in scheduler.assigned_dp_rank
|
|
614
|
-
assert "req3" not in scheduler.assigned_dp_rank
|
|
615
|
-
assert "req2" in scheduler.assigned_dp_rank
|
|
616
|
-
|
|
617
|
-
def test_finish_requests_single_and_multiple(
|
|
618
|
-
self, mock_vllm_config, mock_kv_cache_config,
|
|
619
|
-
mock_structured_output_manager):
|
|
620
|
-
"""Test finish_requests handles single string and list."""
|
|
621
|
-
scheduler = self._create_dp_scheduler_with_mocks(
|
|
622
|
-
mock_vllm_config, mock_kv_cache_config,
|
|
623
|
-
mock_structured_output_manager)
|
|
624
|
-
|
|
625
|
-
# Setup assigned ranks
|
|
626
|
-
scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}
|
|
627
|
-
|
|
628
|
-
# Mock scheduler finish_requests
|
|
629
|
-
scheduler.schedulers[0].finish_requests = MagicMock()
|
|
630
|
-
scheduler.schedulers[1].finish_requests = MagicMock()
|
|
631
|
-
|
|
632
|
-
# Test with single string
|
|
633
|
-
scheduler.finish_requests("req1", finished_status="completed")
|
|
634
|
-
scheduler.schedulers[0].finish_requests.assert_called_with(["req1"],
|
|
635
|
-
"completed")
|
|
636
|
-
|
|
637
|
-
# Test with list
|
|
638
|
-
scheduler.schedulers[0].finish_requests.reset_mock()
|
|
639
|
-
scheduler.schedulers[1].finish_requests.reset_mock()
|
|
640
|
-
|
|
641
|
-
scheduler.finish_requests(["req1", "req2"],
|
|
642
|
-
finished_status="completed")
|
|
643
|
-
scheduler.schedulers[0].finish_requests.assert_called_once_with(
|
|
644
|
-
["req1"], "completed")
|
|
645
|
-
scheduler.schedulers[1].finish_requests.assert_called_once_with(
|
|
646
|
-
["req2"], "completed")
|
|
317
|
+
)
|
|
318
|
+
mock_output_1.scheduled_spec_decode_tokens = {}
|
|
319
|
+
mock_output_1.scheduled_encoder_inputs = {}
|
|
320
|
+
mock_output_1.num_common_prefix_blocks = []
|
|
321
|
+
|
|
322
|
+
# Setup mock queue responses with tuple keys - need to mock .get()
|
|
323
|
+
mock_queue_0 = MagicMock()
|
|
324
|
+
mock_queue_0.get.return_value = mock_output_0
|
|
325
|
+
mock_queue_1 = MagicMock()
|
|
326
|
+
mock_queue_1.get.return_value = mock_output_1
|
|
327
|
+
|
|
328
|
+
scheduler.output_queues = {
|
|
329
|
+
(0, "schedule"): mock_queue_0,
|
|
330
|
+
(1, "schedule"): mock_queue_1,
|
|
331
|
+
}
|
|
332
|
+
|
|
333
|
+
# Setup assigned ranks
|
|
334
|
+
scheduler.assigned_dp_rank = {"req1": 0, "req2": 1}
|
|
335
|
+
|
|
336
|
+
output = scheduler.schedule()
|
|
337
|
+
|
|
338
|
+
# Verify SCHEDULE commands were sent
|
|
339
|
+
scheduler.input_queues[0].put.assert_called_with(
|
|
340
|
+
(SchedulerCommand.SCHEDULE, None))
|
|
341
|
+
scheduler.input_queues[1].put.assert_called_with(
|
|
342
|
+
(SchedulerCommand.SCHEDULE, None))
|
|
343
|
+
|
|
344
|
+
# Verify combined output
|
|
345
|
+
assert isinstance(output, DPSchedulerOutput)
|
|
346
|
+
assert output.total_num_scheduled_tokens == 30
|
|
347
|
+
assert "req1" in output.num_scheduled_tokens
|
|
348
|
+
assert "req2" in output.num_scheduled_tokens
|
|
349
|
+
assert output.assigned_dp_rank == {"req1": 0, "req2": 1}
|
|
647
350
|
|
|
648
|
-
def
|
|
351
|
+
def test_combine_cached_request_data(self, mock_vllm_config,
|
|
649
352
|
mock_kv_cache_config,
|
|
650
353
|
mock_structured_output_manager):
|
|
651
|
-
"""Test
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
354
|
+
"""Test _combine_cached_request_data combines data from all ranks."""
|
|
355
|
+
with patch(
|
|
356
|
+
'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
|
|
357
|
+
):
|
|
358
|
+
with patch('multiprocessing.get_context'):
|
|
359
|
+
scheduler = DPScheduler(
|
|
360
|
+
vllm_config=mock_vllm_config,
|
|
361
|
+
kv_cache_config=mock_kv_cache_config,
|
|
362
|
+
structured_output_manager=mock_structured_output_manager,
|
|
363
|
+
block_size=16,
|
|
364
|
+
)
|
|
365
|
+
|
|
366
|
+
# Create mock rank outputs
|
|
367
|
+
output_0 = MagicMock(spec=SchedulerOutput)
|
|
368
|
+
output_0.scheduled_cached_reqs = CachedRequestData(
|
|
369
|
+
req_ids=["req1"],
|
|
370
|
+
resumed_req_ids=["req1"],
|
|
371
|
+
new_token_ids=[[1, 2, 3]],
|
|
372
|
+
all_token_ids=[[1, 2, 3, 4, 5]],
|
|
373
|
+
new_block_ids=[[10, 11]],
|
|
374
|
+
num_computed_tokens=[5],
|
|
375
|
+
num_output_tokens=[3],
|
|
376
|
+
)
|
|
377
|
+
|
|
378
|
+
output_1 = MagicMock(spec=SchedulerOutput)
|
|
379
|
+
output_1.scheduled_cached_reqs = CachedRequestData(
|
|
380
|
+
req_ids=["req2"],
|
|
381
|
+
resumed_req_ids=[],
|
|
382
|
+
new_token_ids=[[6, 7]],
|
|
383
|
+
all_token_ids=[[6, 7, 8, 9]],
|
|
384
|
+
new_block_ids=[[20, 21]],
|
|
385
|
+
num_computed_tokens=[4],
|
|
386
|
+
num_output_tokens=[2],
|
|
387
|
+
)
|
|
388
|
+
|
|
389
|
+
combined = scheduler._combine_cached_request_data(
|
|
390
|
+
[output_0, output_1])
|
|
391
|
+
|
|
392
|
+
# Verify combined data
|
|
393
|
+
assert combined.req_ids == ["req1", "req2"]
|
|
394
|
+
assert combined.resumed_req_ids == ["req1"]
|
|
395
|
+
assert combined.new_token_ids == [[1, 2, 3], [6, 7]]
|
|
396
|
+
assert combined.num_computed_tokens == [5, 4]
|
|
397
|
+
assert combined.num_output_tokens == [3, 2]
|
|
398
|
+
|
|
399
|
+
def test_finish_requests_routes_to_workers(self, mock_vllm_config,
|
|
400
|
+
mock_kv_cache_config,
|
|
401
|
+
mock_structured_output_manager):
|
|
402
|
+
"""Test finish_requests sends FINISH_REQUESTS command to appropriate workers."""
|
|
403
|
+
with patch(
|
|
404
|
+
'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
|
|
405
|
+
):
|
|
406
|
+
with patch('multiprocessing.get_context'):
|
|
407
|
+
scheduler = DPScheduler(
|
|
408
|
+
vllm_config=mock_vllm_config,
|
|
409
|
+
kv_cache_config=mock_kv_cache_config,
|
|
410
|
+
structured_output_manager=mock_structured_output_manager,
|
|
411
|
+
block_size=16,
|
|
412
|
+
)
|
|
413
|
+
|
|
414
|
+
scheduler.input_queues = [MagicMock(), MagicMock()]
|
|
415
|
+
scheduler.output_queues = {
|
|
416
|
+
(0, "finish_requests"): MagicMock(),
|
|
417
|
+
(1, "finish_requests"): MagicMock(),
|
|
418
|
+
}
|
|
419
|
+
|
|
420
|
+
scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}
|
|
421
|
+
|
|
422
|
+
# Test with list of requests
|
|
423
|
+
scheduler.finish_requests(["req1", "req2"],
|
|
424
|
+
finished_status="completed")
|
|
425
|
+
|
|
426
|
+
# Verify FINISH_REQUESTS commands were sent to correct ranks
|
|
427
|
+
scheduler.input_queues[0].put.assert_called()
|
|
428
|
+
scheduler.input_queues[1].put.assert_called()
|
|
660
429
|
|
|
661
|
-
|
|
662
|
-
|
|
430
|
+
def test_get_num_unfinished_requests(self, mock_vllm_config,
|
|
431
|
+
mock_kv_cache_config,
|
|
432
|
+
mock_structured_output_manager):
|
|
433
|
+
"""Test get_num_unfinished_requests queries all workers."""
|
|
434
|
+
with patch(
|
|
435
|
+
'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
|
|
436
|
+
):
|
|
437
|
+
with patch('multiprocessing.get_context'):
|
|
438
|
+
scheduler = DPScheduler(
|
|
439
|
+
vllm_config=mock_vllm_config,
|
|
440
|
+
kv_cache_config=mock_kv_cache_config,
|
|
441
|
+
structured_output_manager=mock_structured_output_manager,
|
|
442
|
+
block_size=16,
|
|
443
|
+
)
|
|
444
|
+
|
|
445
|
+
scheduler.input_queues = [MagicMock(), MagicMock()]
|
|
446
|
+
|
|
447
|
+
# Create proper mocks for queue.get() calls
|
|
448
|
+
mock_queue_0 = MagicMock()
|
|
449
|
+
mock_queue_0.get.return_value = 5
|
|
450
|
+
mock_queue_1 = MagicMock()
|
|
451
|
+
mock_queue_1.get.return_value = 3
|
|
452
|
+
|
|
453
|
+
scheduler.output_queues = {
|
|
454
|
+
(0, "get_num_unfinished_requests"): mock_queue_0,
|
|
455
|
+
(1, "get_num_unfinished_requests"): mock_queue_1,
|
|
456
|
+
}
|
|
457
|
+
|
|
458
|
+
total = scheduler.get_num_unfinished_requests()
|
|
459
|
+
|
|
460
|
+
# Verify commands were sent
|
|
461
|
+
scheduler.input_queues[0].put.assert_called_with(
|
|
462
|
+
(SchedulerCommand.GET_NUM_UNFINISHED_REQUESTS, None))
|
|
463
|
+
scheduler.input_queues[1].put.assert_called_with(
|
|
464
|
+
(SchedulerCommand.GET_NUM_UNFINISHED_REQUESTS, None))
|
|
465
|
+
|
|
466
|
+
assert total == 8
|
|
663
467
|
|
|
664
468
|
def test_has_finished_requests(self, mock_vllm_config,
|
|
665
469
|
mock_kv_cache_config,
|
|
666
470
|
mock_structured_output_manager):
|
|
667
|
-
"""Test has_finished_requests checks all
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
471
|
+
"""Test has_finished_requests checks all workers."""
|
|
472
|
+
with patch(
|
|
473
|
+
'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
|
|
474
|
+
):
|
|
475
|
+
with patch('multiprocessing.get_context'):
|
|
476
|
+
scheduler = DPScheduler(
|
|
477
|
+
vllm_config=mock_vllm_config,
|
|
478
|
+
kv_cache_config=mock_kv_cache_config,
|
|
479
|
+
structured_output_manager=mock_structured_output_manager,
|
|
480
|
+
block_size=16,
|
|
481
|
+
)
|
|
482
|
+
|
|
483
|
+
scheduler.input_queues = [MagicMock(), MagicMock()]
|
|
484
|
+
|
|
485
|
+
# Create proper mocks for queue.get() calls
|
|
486
|
+
mock_queue_0 = MagicMock()
|
|
487
|
+
mock_queue_0.get.return_value = False
|
|
488
|
+
mock_queue_1 = MagicMock()
|
|
489
|
+
mock_queue_1.get.return_value = True
|
|
490
|
+
|
|
491
|
+
scheduler.output_queues = {
|
|
492
|
+
(0, "has_finished_requests"): mock_queue_0,
|
|
493
|
+
(1, "has_finished_requests"): mock_queue_1,
|
|
494
|
+
}
|
|
495
|
+
|
|
496
|
+
result = scheduler.has_finished_requests()
|
|
497
|
+
|
|
498
|
+
assert result is True
|
|
499
|
+
|
|
500
|
+
# Verify commands were sent
|
|
501
|
+
scheduler.input_queues[0].put.assert_called_with(
|
|
502
|
+
(SchedulerCommand.HAS_FINISHED_REQUESTS, None))
|
|
503
|
+
scheduler.input_queues[1].put.assert_called_with(
|
|
504
|
+
(SchedulerCommand.HAS_FINISHED_REQUESTS, None))
|
|
690
505
|
|
|
691
506
|
def test_get_request_counts(self, mock_vllm_config, mock_kv_cache_config,
|
|
692
507
|
mock_structured_output_manager):
|
|
693
|
-
"""Test get_request_counts
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
508
|
+
"""Test get_request_counts queries all workers."""
|
|
509
|
+
with patch(
|
|
510
|
+
'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
|
|
511
|
+
):
|
|
512
|
+
with patch('multiprocessing.get_context'):
|
|
513
|
+
scheduler = DPScheduler(
|
|
514
|
+
vllm_config=mock_vllm_config,
|
|
515
|
+
kv_cache_config=mock_kv_cache_config,
|
|
516
|
+
structured_output_manager=mock_structured_output_manager,
|
|
517
|
+
block_size=16,
|
|
518
|
+
)
|
|
519
|
+
|
|
520
|
+
scheduler.input_queues = [MagicMock(), MagicMock()]
|
|
521
|
+
|
|
522
|
+
# Create proper mocks for queue.get() calls
|
|
523
|
+
mock_queue_0 = MagicMock()
|
|
524
|
+
mock_queue_0.get.return_value = (2, 1)
|
|
525
|
+
mock_queue_1 = MagicMock()
|
|
526
|
+
mock_queue_1.get.return_value = (1, 3)
|
|
527
|
+
|
|
528
|
+
scheduler.output_queues = {
|
|
529
|
+
(0, "get_request_counts"): mock_queue_0,
|
|
530
|
+
(1, "get_request_counts"): mock_queue_1,
|
|
531
|
+
}
|
|
532
|
+
|
|
533
|
+
running, waiting = scheduler.get_request_counts()
|
|
534
|
+
|
|
535
|
+
# Verify commands were sent
|
|
536
|
+
scheduler.input_queues[0].put.assert_called_with(
|
|
537
|
+
(SchedulerCommand.GET_REQUEST_COUNTS, None))
|
|
538
|
+
scheduler.input_queues[1].put.assert_called_with(
|
|
539
|
+
(SchedulerCommand.GET_REQUEST_COUNTS, None))
|
|
540
|
+
|
|
541
|
+
assert running == 3
|
|
542
|
+
assert waiting == 4
|
|
711
543
|
|
|
712
544
|
def test_reset_prefix_cache(self, mock_vllm_config, mock_kv_cache_config,
|
|
713
545
|
mock_structured_output_manager):
|
|
714
|
-
"""Test reset_prefix_cache
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
)
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
546
|
+
"""Test reset_prefix_cache sends command to all workers."""
|
|
547
|
+
with patch(
|
|
548
|
+
'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
|
|
549
|
+
):
|
|
550
|
+
with patch('multiprocessing.get_context'):
|
|
551
|
+
scheduler = DPScheduler(
|
|
552
|
+
vllm_config=mock_vllm_config,
|
|
553
|
+
kv_cache_config=mock_kv_cache_config,
|
|
554
|
+
structured_output_manager=mock_structured_output_manager,
|
|
555
|
+
block_size=16,
|
|
556
|
+
)
|
|
557
|
+
|
|
558
|
+
scheduler.input_queues = [MagicMock(), MagicMock()]
|
|
559
|
+
|
|
560
|
+
# Create proper mocks for queue.get() calls
|
|
561
|
+
mock_queue_0 = MagicMock()
|
|
562
|
+
mock_queue_0.get.return_value = True
|
|
563
|
+
mock_queue_1 = MagicMock()
|
|
564
|
+
mock_queue_1.get.return_value = True
|
|
565
|
+
|
|
566
|
+
scheduler.output_queues = {
|
|
567
|
+
(0, "reset_prefix_cache"): mock_queue_0,
|
|
568
|
+
(1, "reset_prefix_cache"): mock_queue_1,
|
|
569
|
+
}
|
|
570
|
+
|
|
571
|
+
result = scheduler.reset_prefix_cache()
|
|
572
|
+
|
|
573
|
+
# Verify commands were sent
|
|
574
|
+
scheduler.input_queues[0].put.assert_called_with(
|
|
575
|
+
(SchedulerCommand.RESET_PREFIX_CACHE, None))
|
|
576
|
+
scheduler.input_queues[1].put.assert_called_with(
|
|
577
|
+
(SchedulerCommand.RESET_PREFIX_CACHE, None))
|
|
578
|
+
|
|
579
|
+
assert result is True
|
|
580
|
+
|
|
581
|
+
def test_make_stats_aggregates_from_workers(
|
|
582
|
+
self, mock_vllm_config, mock_kv_cache_config,
|
|
583
|
+
mock_structured_output_manager):
|
|
584
|
+
"""Test make_stats aggregates statistics from all workers."""
|
|
585
|
+
with patch(
|
|
586
|
+
'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
|
|
587
|
+
):
|
|
588
|
+
with patch('multiprocessing.get_context'):
|
|
589
|
+
scheduler = DPScheduler(
|
|
590
|
+
vllm_config=mock_vllm_config,
|
|
591
|
+
kv_cache_config=mock_kv_cache_config,
|
|
592
|
+
structured_output_manager=mock_structured_output_manager,
|
|
593
|
+
block_size=16,
|
|
594
|
+
log_stats=True,
|
|
595
|
+
)
|
|
596
|
+
|
|
597
|
+
scheduler.input_queues = [MagicMock(), MagicMock()]
|
|
598
|
+
|
|
599
|
+
# Create mock stats
|
|
600
|
+
stats_0 = SchedulerStats(
|
|
601
|
+
num_running_reqs=3,
|
|
602
|
+
num_waiting_reqs=2,
|
|
603
|
+
kv_cache_usage=0.5,
|
|
604
|
+
prefix_cache_stats=PrefixCacheStats(reset=False,
|
|
605
|
+
requests=10,
|
|
606
|
+
queries=8,
|
|
607
|
+
hits=5),
|
|
608
|
+
connector_prefix_cache_stats=PrefixCacheStats(reset=False,
|
|
609
|
+
requests=5,
|
|
610
|
+
queries=4,
|
|
611
|
+
hits=2),
|
|
612
|
+
spec_decoding_stats=None,
|
|
613
|
+
kv_connector_stats=None,
|
|
614
|
+
)
|
|
615
|
+
|
|
616
|
+
stats_1 = SchedulerStats(
|
|
617
|
+
num_running_reqs=4,
|
|
618
|
+
num_waiting_reqs=1,
|
|
619
|
+
kv_cache_usage=0.7,
|
|
620
|
+
prefix_cache_stats=PrefixCacheStats(reset=False,
|
|
621
|
+
requests=15,
|
|
622
|
+
queries=12,
|
|
623
|
+
hits=8),
|
|
624
|
+
connector_prefix_cache_stats=PrefixCacheStats(reset=False,
|
|
625
|
+
requests=6,
|
|
626
|
+
queries=5,
|
|
627
|
+
hits=3),
|
|
628
|
+
spec_decoding_stats=None,
|
|
629
|
+
kv_connector_stats=None,
|
|
630
|
+
)
|
|
631
|
+
|
|
632
|
+
# Create proper mocks for queue.get() calls
|
|
633
|
+
mock_queue_0 = MagicMock()
|
|
634
|
+
mock_queue_0.get.return_value = stats_0
|
|
635
|
+
mock_queue_1 = MagicMock()
|
|
636
|
+
mock_queue_1.get.return_value = stats_1
|
|
637
|
+
|
|
638
|
+
scheduler.output_queues = {
|
|
639
|
+
(0, "make_stats"): mock_queue_0,
|
|
640
|
+
(1, "make_stats"): mock_queue_1,
|
|
641
|
+
}
|
|
642
|
+
|
|
643
|
+
combined_stats = scheduler.make_stats()
|
|
644
|
+
|
|
645
|
+
# Verify commands were sent
|
|
646
|
+
scheduler.input_queues[0].put.assert_called_with(
|
|
647
|
+
(SchedulerCommand.MAKE_STATS, (None, None)))
|
|
648
|
+
scheduler.input_queues[1].put.assert_called_with(
|
|
649
|
+
(SchedulerCommand.MAKE_STATS, (None, None)))
|
|
650
|
+
|
|
651
|
+
assert combined_stats.num_running_reqs == 7
|
|
652
|
+
assert combined_stats.num_waiting_reqs == 3
|
|
653
|
+
assert combined_stats.kv_cache_usage == 0.6
|
|
654
|
+
|
|
655
|
+
def test_make_stats_returns_none_when_disabled(
|
|
656
|
+
self, mock_vllm_config, mock_kv_cache_config,
|
|
657
|
+
mock_structured_output_manager):
|
|
658
|
+
"""Test make_stats returns None when logging disabled."""
|
|
659
|
+
with patch(
|
|
660
|
+
'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
|
|
661
|
+
):
|
|
662
|
+
with patch('multiprocessing.get_context'):
|
|
663
|
+
scheduler = DPScheduler(
|
|
664
|
+
vllm_config=mock_vllm_config,
|
|
665
|
+
kv_cache_config=mock_kv_cache_config,
|
|
666
|
+
structured_output_manager=mock_structured_output_manager,
|
|
667
|
+
block_size=16,
|
|
668
|
+
log_stats=False,
|
|
669
|
+
)
|
|
670
|
+
|
|
671
|
+
stats = scheduler.make_stats()
|
|
672
|
+
assert stats is None
|
|
810
673
|
|
|
811
674
|
def test_update_draft_token_ids(self, mock_vllm_config,
|
|
812
675
|
mock_kv_cache_config,
|
|
813
676
|
mock_structured_output_manager):
|
|
814
|
-
"""Test update_draft_token_ids routes
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
# Check rank 1 got req2
|
|
848
|
-
call_args_1 = scheduler.schedulers[1].update_draft_token_ids.call_args[
|
|
849
|
-
0][0]
|
|
850
|
-
assert "req2" in call_args_1.req_ids
|
|
677
|
+
"""Test update_draft_token_ids routes to correct workers."""
|
|
678
|
+
with patch(
|
|
679
|
+
'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
|
|
680
|
+
):
|
|
681
|
+
with patch('multiprocessing.get_context'):
|
|
682
|
+
scheduler = DPScheduler(
|
|
683
|
+
vllm_config=mock_vllm_config,
|
|
684
|
+
kv_cache_config=mock_kv_cache_config,
|
|
685
|
+
structured_output_manager=mock_structured_output_manager,
|
|
686
|
+
block_size=16,
|
|
687
|
+
)
|
|
688
|
+
|
|
689
|
+
scheduler.input_queues = [MagicMock(), MagicMock()]
|
|
690
|
+
scheduler.output_queues = {
|
|
691
|
+
(0, "update_draft_token_ids"): MagicMock(),
|
|
692
|
+
(1, "update_draft_token_ids"): MagicMock(),
|
|
693
|
+
}
|
|
694
|
+
|
|
695
|
+
scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}
|
|
696
|
+
|
|
697
|
+
draft_token_ids = MagicMock()
|
|
698
|
+
draft_token_ids.req_ids = ["req1", "req2", "req3"]
|
|
699
|
+
draft_token_ids.draft_token_ids = [
|
|
700
|
+
[101, 102, 103],
|
|
701
|
+
[201, 202],
|
|
702
|
+
[301, 302, 303, 304],
|
|
703
|
+
]
|
|
704
|
+
|
|
705
|
+
scheduler.update_draft_token_ids(draft_token_ids)
|
|
706
|
+
|
|
707
|
+
# Verify commands were sent to correct workers
|
|
708
|
+
scheduler.input_queues[0].put.assert_called()
|
|
709
|
+
scheduler.input_queues[1].put.assert_called()
|
|
851
710
|
|
|
852
711
|
def test_shutdown(self, mock_vllm_config, mock_kv_cache_config,
|
|
853
712
|
mock_structured_output_manager):
|
|
854
|
-
"""Test shutdown
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
713
|
+
"""Test shutdown sends SHUTDOWN command to all workers."""
|
|
714
|
+
with patch(
|
|
715
|
+
'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
|
|
716
|
+
):
|
|
717
|
+
with patch('multiprocessing.get_context'):
|
|
718
|
+
scheduler = DPScheduler(
|
|
719
|
+
vllm_config=mock_vllm_config,
|
|
720
|
+
kv_cache_config=mock_kv_cache_config,
|
|
721
|
+
structured_output_manager=mock_structured_output_manager,
|
|
722
|
+
block_size=16,
|
|
723
|
+
)
|
|
724
|
+
|
|
725
|
+
scheduler.input_queues = [MagicMock(), MagicMock()]
|
|
726
|
+
scheduler.output_queues = {
|
|
727
|
+
(0, "shutdown"): MagicMock(),
|
|
728
|
+
(1, "shutdown"): MagicMock(),
|
|
729
|
+
}
|
|
730
|
+
|
|
731
|
+
mock_process_0 = MagicMock()
|
|
732
|
+
mock_process_1 = MagicMock()
|
|
733
|
+
mock_process_0.is_alive = MagicMock(return_value=False)
|
|
734
|
+
mock_process_1.is_alive = MagicMock(return_value=False)
|
|
735
|
+
scheduler.processes = [mock_process_0, mock_process_1]
|
|
736
|
+
|
|
737
|
+
scheduler.shutdown()
|
|
738
|
+
|
|
739
|
+
# Verify SHUTDOWN commands were sent
|
|
740
|
+
scheduler.input_queues[0].put.assert_called_with(
|
|
741
|
+
(SchedulerCommand.SHUTDOWN, None))
|
|
742
|
+
scheduler.input_queues[1].put.assert_called_with(
|
|
743
|
+
(SchedulerCommand.SHUTDOWN, None))
|
|
744
|
+
|
|
745
|
+
# Verify processes were joined
|
|
746
|
+
mock_process_0.join.assert_called()
|
|
747
|
+
mock_process_1.join.assert_called()
|
|
866
748
|
|
|
867
749
|
|
|
868
750
|
class TestUpdateVllmConfigForDPScheduler:
|