tpu-inference 0.11.1.dev202511150811__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,899 @@
|
|
|
1
|
+
from unittest.mock import MagicMock, patch
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
import torch
|
|
5
|
+
from vllm.config import VllmConfig
|
|
6
|
+
from vllm.v1.core.sched.output import (CachedRequestData, GrammarOutput,
|
|
7
|
+
SchedulerOutput)
|
|
8
|
+
from vllm.v1.core.sched.scheduler import Scheduler
|
|
9
|
+
from vllm.v1.engine import EngineCoreOutputs
|
|
10
|
+
from vllm.v1.kv_cache_interface import KVCacheConfig
|
|
11
|
+
from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
|
|
12
|
+
from vllm.v1.outputs import ModelRunnerOutput
|
|
13
|
+
from vllm.v1.request import Request
|
|
14
|
+
|
|
15
|
+
from tpu_inference.core.sched.dp_scheduler import (
|
|
16
|
+
DPScheduler, DPSchedulerOutput, update_vllm_config_for_dp_scheduler)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class TestDPScheduler:
|
|
20
|
+
|
|
21
|
+
@pytest.fixture
def mock_vllm_config(self):
    """Return a ``VllmConfig``-spec'd mock configured for two DP ranks.

    The scheduler section points ``_original_scheduler_cls`` at the real
    ``Scheduler`` class; individual tests patch that attribute with a mock
    factory before constructing a ``DPScheduler``.
    """
    # Sharding section: only the total DP size is set here.
    sharding = MagicMock()
    sharding.total_dp_size = 2

    # Scheduler section: limits plus the wrapped scheduler class.
    sched_cfg = MagicMock()
    sched_cfg._original_scheduler_cls = Scheduler
    sched_cfg.max_num_seqs = 8
    sched_cfg.max_num_batched_tokens = 1024
    sched_cfg.async_scheduling = False

    vllm_config = MagicMock(spec=VllmConfig)
    vllm_config.sharding_config = sharding
    vllm_config.scheduler_config = sched_cfg
    return vllm_config
|
|
33
|
+
|
|
34
|
+
@pytest.fixture
def mock_kv_cache_config(self):
    """Return a ``KVCacheConfig``-spec'd mock reporting 100 blocks."""
    # Passing the attribute as a keyword sets it on the mock at creation.
    return MagicMock(spec=KVCacheConfig, num_blocks=100)
|
|
40
|
+
|
|
41
|
+
@pytest.fixture
def mock_structured_output_manager(self):
    """Return a plain ``MagicMock`` standing in for the structured output manager."""
    manager = MagicMock()
    return manager
|
|
45
|
+
|
|
46
|
+
def _create_dp_scheduler_with_mocks(self, mock_vllm_config,
                                    mock_kv_cache_config,
                                    mock_structured_output_manager,
                                    **kwargs):
    """Build a ``DPScheduler`` whose two rank schedulers are distinct mocks.

    The configured scheduler class is temporarily replaced by a factory
    that hands out a different ``MagicMock`` on each of its first two
    calls, so tests can configure each rank independently. Extra keyword
    arguments are forwarded to the ``DPScheduler`` constructor.
    """
    # One independent mock per DP rank, handed out in call order.
    rank_schedulers = [MagicMock(), MagicMock()]
    scheduler_factory = MagicMock(side_effect=rank_schedulers)

    with patch.object(mock_vllm_config.scheduler_config,
                      '_original_scheduler_cls', scheduler_factory):
        return DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
            **kwargs)
|
|
67
|
+
|
|
68
|
+
def test_init_creates_per_rank_schedulers(
    self,
    mock_vllm_config,
    mock_kv_cache_config,
    mock_structured_output_manager,
):
    """Constructing a DPScheduler yields one scheduler and KV config per rank."""
    # Stand-in scheduler class that returns a mock instance when called.
    rank_scheduler = MagicMock()
    scheduler_factory = MagicMock(return_value=rank_scheduler)
    scheduler_config = mock_vllm_config.scheduler_config

    with patch.object(scheduler_config, '_original_scheduler_cls',
                      scheduler_factory):
        dp_scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
            log_stats=True,
        )

    # Basic bookkeeping reflects the two configured DP ranks.
    assert dp_scheduler.dp_size == 2
    assert dp_scheduler.log_stats is True
    assert len(dp_scheduler.schedulers) == 2
    assert len(dp_scheduler.per_rank_kv_cache_configs) == 2

    # The 100 configured blocks are expected to be split evenly: 100 / 2.
    for per_rank_config in dp_scheduler.per_rank_kv_cache_configs:
        assert per_rank_config.num_blocks == 50
|
|
98
|
+
|
|
99
|
+
def test_get_rank_token_counts(self, mock_vllm_config,
                               mock_kv_cache_config,
                               mock_structured_output_manager):
    """_get_rank_token_counts sums running and waiting tokens for each rank."""
    dp_scheduler = self._create_dp_scheduler_with_mocks(
        mock_vllm_config, mock_kv_cache_config,
        mock_structured_output_manager)

    # Three fake requests carrying only a token count.
    running_on_0 = MagicMock(num_tokens=10)
    waiting_on_0 = MagicMock(num_tokens=20)
    running_on_1 = MagicMock(num_tokens=15)

    rank0 = dp_scheduler.schedulers[0]
    rank0.running = [running_on_0]
    rank0.waiting = [waiting_on_0]

    rank1 = dp_scheduler.schedulers[1]
    rank1.running = [running_on_1]
    rank1.waiting = []

    token_counts = dp_scheduler._get_rank_token_counts()

    assert token_counts[0] == 30  # 10 running + 20 waiting
    assert token_counts[1] == 15
|
|
124
|
+
|
|
125
|
+
def test_find_best_rank_with_cache_hit(self, mock_vllm_config,
                                       mock_kv_cache_config,
                                       mock_structured_output_manager):
    """_find_best_rank_for_request picks the rank reporting more cached tokens."""
    dp_scheduler = self._create_dp_scheduler_with_mocks(
        mock_vllm_config, mock_kv_cache_config,
        mock_structured_output_manager)

    request = MagicMock(spec=Request)

    # Rank 0 reports 10 cached tokens, rank 1 reports 20 (the better hit).
    # Both ranks have empty running/waiting queues.
    for rank, cached_tokens in enumerate((10, 20)):
        rank_scheduler = dp_scheduler.schedulers[rank]
        rank_scheduler.kv_cache_manager = MagicMock()
        rank_scheduler.kv_cache_manager.get_computed_blocks.return_value = (
            [],
            cached_tokens,
        )
        rank_scheduler.running = []
        rank_scheduler.waiting = []

    chosen_rank = dp_scheduler._find_best_rank_for_request(request)

    assert chosen_rank == 1
|
|
161
|
+
|
|
162
|
+
def test_find_best_rank_without_cache_hit(self, mock_vllm_config,
                                          mock_kv_cache_config,
                                          mock_structured_output_manager):
    """With no cache hit on any rank, the rank holding fewer tokens is chosen."""
    dp_scheduler = self._create_dp_scheduler_with_mocks(
        mock_vllm_config, mock_kv_cache_config,
        mock_structured_output_manager)

    request = MagicMock(spec=Request)

    # Neither rank reports cached tokens; rank 0 runs a 50-token request
    # and rank 1 a 30-token one, so load is the only differentiator.
    for rank, running_tokens in enumerate((50, 30)):
        rank_scheduler = dp_scheduler.schedulers[rank]
        rank_scheduler.kv_cache_manager = MagicMock()
        rank_scheduler.kv_cache_manager.get_computed_blocks.return_value = (
            [], 0)
        rank_scheduler.running = [MagicMock(num_tokens=running_tokens)]
        rank_scheduler.waiting = []

    chosen_rank = dp_scheduler._find_best_rank_for_request(request)

    # Rank 1 carries fewer tokens.
    assert chosen_rank == 1
|
|
197
|
+
|
|
198
|
+
def test_add_request_assigns_to_best_rank(self, mock_vllm_config,
                                          mock_kv_cache_config,
                                          mock_structured_output_manager):
    """add_request records the selected rank and forwards only to that rank."""
    dp_scheduler = self._create_dp_scheduler_with_mocks(
        mock_vllm_config, mock_kv_cache_config,
        mock_structured_output_manager)

    request = MagicMock(spec=Request)
    request.request_id = "req1"

    # Force rank selection to rank 1 regardless of scheduler state.
    dp_scheduler._find_best_rank_for_request = MagicMock(return_value=1)

    # Give every rank scheduler a fresh add_request spy.
    for rank in (0, 1):
        dp_scheduler.schedulers[rank].add_request = MagicMock()

    dp_scheduler.add_request(request)

    selected = dp_scheduler.schedulers[1]
    skipped = dp_scheduler.schedulers[0]
    assert dp_scheduler.assigned_dp_rank["req1"] == 1
    selected.add_request.assert_called_once_with(request)
    skipped.add_request.assert_not_called()
|
|
224
|
+
|
|
225
|
+
def test_schedule_runs_all_schedulers(self, mock_vllm_config,
                                      mock_kv_cache_config,
                                      mock_structured_output_manager):
    """schedule() calls every rank scheduler and merges their outputs."""

    def build_rank_output(req_id, token_count):
        # A SchedulerOutput-spec'd mock scheduling one request and nothing else.
        rank_output = MagicMock(spec=SchedulerOutput)
        rank_output.scheduled_new_reqs = []
        rank_output.num_scheduled_tokens = {req_id: token_count}
        rank_output.total_num_scheduled_tokens = token_count
        rank_output.finished_req_ids = set()
        rank_output.scheduled_cached_reqs = CachedRequestData(
            req_ids=[],
            resumed_req_ids=[],
            new_token_ids=[],
            all_token_ids=[],
            new_block_ids=[],
            num_computed_tokens=[],
            num_output_tokens=[],
        )
        rank_output.scheduled_spec_decode_tokens = {}
        rank_output.scheduled_encoder_inputs = {}
        rank_output.num_common_prefix_blocks = []
        return rank_output

    dp_scheduler = self._create_dp_scheduler_with_mocks(
        mock_vllm_config, mock_kv_cache_config,
        mock_structured_output_manager)

    # Rank 0 schedules req1 (10 tokens); rank 1 schedules req2 (20 tokens).
    per_rank_outputs = [
        build_rank_output("req1", 10),
        build_rank_output("req2", 20),
    ]
    for rank, rank_output in enumerate(per_rank_outputs):
        rank_scheduler = dp_scheduler.schedulers[rank]
        rank_scheduler.schedule = MagicMock(return_value=rank_output)
        rank_scheduler.running = []
        rank_scheduler.waiting = []

    # Requests are already pinned to their ranks.
    dp_scheduler.assigned_dp_rank = {"req1": 0, "req2": 1}

    combined = dp_scheduler.schedule()

    assert isinstance(combined, DPSchedulerOutput)
    assert combined.total_num_scheduled_tokens == 30  # 10 + 20
    assert "req1" in combined.num_scheduled_tokens
    assert "req2" in combined.num_scheduled_tokens
    assert combined.assigned_dp_rank == {"req1": 0, "req2": 1}
|
|
290
|
+
|
|
291
|
+
def test_combine_cached_request_data(self, mock_vllm_config,
                                     mock_kv_cache_config,
                                     mock_structured_output_manager):
    """_combine_cached_request_data concatenates each field in rank order."""
    scheduler_factory = MagicMock(return_value=MagicMock())
    with patch.object(mock_vllm_config.scheduler_config,
                      '_original_scheduler_cls', scheduler_factory):
        dp_scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

    # Rank 0 carries one resumed request.
    first_rank = MagicMock(spec=SchedulerOutput)
    first_rank.scheduled_cached_reqs = CachedRequestData(
        req_ids=["req1"],
        resumed_req_ids=["req1"],
        new_token_ids=[[1, 2, 3]],
        all_token_ids=[[1, 2, 3, 4, 5]],
        new_block_ids=[[10, 11]],
        num_computed_tokens=[5],
        num_output_tokens=[3],
    )

    # Rank 1 carries one request that was not resumed.
    second_rank = MagicMock(spec=SchedulerOutput)
    second_rank.scheduled_cached_reqs = CachedRequestData(
        req_ids=["req2"],
        resumed_req_ids=[],
        new_token_ids=[[6, 7]],
        all_token_ids=[[6, 7, 8, 9]],
        new_block_ids=[[20, 21]],
        num_computed_tokens=[4],
        num_output_tokens=[2],
    )

    merged = dp_scheduler._combine_cached_request_data(
        [first_rank, second_rank])

    # Every field should be rank 0's entries followed by rank 1's.
    expected_fields = {
        "req_ids": ["req1", "req2"],
        "resumed_req_ids": ["req1"],
        "new_token_ids": [[1, 2, 3], [6, 7]],
        "all_token_ids": [[1, 2, 3, 4, 5], [6, 7, 8, 9]],
        "new_block_ids": [[10, 11], [20, 21]],
        "num_computed_tokens": [5, 4],
        "num_output_tokens": [3, 2],
    }
    for field_name, expected_value in expected_fields.items():
        assert getattr(merged, field_name) == expected_value
|
|
339
|
+
|
|
340
|
+
def test_get_grammar_bitmask_with_structured_output(
        self, mock_vllm_config, mock_kv_cache_config,
        mock_structured_output_manager):
    """get_grammar_bitmask merges the per-rank grammar outputs.

    Each rank scheduler returns a one-row bitmask for one request; the
    combined result is expected to list both request ids in rank order
    and to hold a (2, 100) bitmask.
    """
    scheduler = self._create_dp_scheduler_with_mocks(
        mock_vllm_config, mock_kv_cache_config,
        mock_structured_output_manager)

    # Placeholder per-rank scheduler outputs (cached from a schedule call).
    mock_output_0 = MagicMock()
    mock_output_1 = MagicMock()

    # Grammar outputs from each rank: all-True for rank 0, all-False for
    # rank 1. torch.zeros keeps the dtype boolean; multiplying a bool
    # tensor by the Python int 0 would promote it to int64 and give the
    # two ranks mismatched dtypes.
    grammar_output_0 = GrammarOutput(
        structured_output_request_ids=["req1"],
        grammar_bitmask=torch.ones((1, 100), dtype=torch.bool),
    )
    grammar_output_1 = GrammarOutput(
        structured_output_request_ids=["req2"],
        grammar_bitmask=torch.zeros((1, 100), dtype=torch.bool),
    )

    scheduler.schedulers[0].get_grammar_bitmask = MagicMock(
        return_value=grammar_output_0)
    scheduler.schedulers[1].get_grammar_bitmask = MagicMock(
        return_value=grammar_output_1)

    # Cache scheduler outputs as schedule() would have done.
    scheduler.cached_schedulers_output.append(
        [mock_output_0, mock_output_1])

    # An otherwise empty combined output to pass to get_grammar_bitmask.
    dp_output = DPSchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=CachedRequestData(
            req_ids=[],
            resumed_req_ids=[],
            new_token_ids=[],
            all_token_ids=[],
            new_block_ids=[],
            num_computed_tokens=[],
            num_output_tokens=[],
        ),
        num_scheduled_tokens={},
        total_num_scheduled_tokens=0,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=[],
        finished_req_ids=set(),
        free_encoder_mm_hashes=set(),
    )

    result = scheduler.get_grammar_bitmask(dp_output)

    assert result is not None
    assert result.structured_output_request_ids == ["req1", "req2"]
    assert result.grammar_bitmask.shape == (2, 100)
|
|
397
|
+
|
|
398
|
+
def test_get_grammar_bitmask_no_structured_output(
        self, mock_vllm_config, mock_kv_cache_config,
        mock_structured_output_manager):
    """get_grammar_bitmask returns None when no rank has structured output."""
    # Use the shared helper so each rank gets its own mock scheduler.
    # A factory built with MagicMock(return_value=MagicMock()) hands the
    # same instance to both ranks, so the second assignment below would
    # overwrite the first instead of configuring a second rank.
    scheduler = self._create_dp_scheduler_with_mocks(
        mock_vllm_config, mock_kv_cache_config,
        mock_structured_output_manager)

    # Both rank schedulers report no grammar output.
    scheduler.schedulers[0].get_grammar_bitmask = MagicMock(
        return_value=None)
    scheduler.schedulers[1].get_grammar_bitmask = MagicMock(
        return_value=None)

    # Cache scheduler outputs as schedule() would have done.
    mock_output_0 = MagicMock()
    mock_output_1 = MagicMock()
    scheduler.cached_schedulers_output.append(
        [mock_output_0, mock_output_1])

    # An otherwise empty combined output to pass to get_grammar_bitmask.
    dp_output = DPSchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=CachedRequestData(
            req_ids=[],
            resumed_req_ids=[],
            new_token_ids=[],
            all_token_ids=[],
            new_block_ids=[],
            num_computed_tokens=[],
            num_output_tokens=[],
        ),
        num_scheduled_tokens={},
        total_num_scheduled_tokens=0,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=[],
        finished_req_ids=set(),
        free_encoder_mm_hashes=set(),
    )

    result = scheduler.get_grammar_bitmask(dp_output)
    assert result is None
|
|
446
|
+
|
|
447
|
+
def test_update_from_output_routes_to_schedulers(
        self, mock_vllm_config, mock_kv_cache_config,
        mock_structured_output_manager):
    """update_from_output updates every rank scheduler and drops finished requests.

    Three requests are spread over two ranks and one of them (req3) is
    reported finished. After the update both rank schedulers must have
    been called and req3 must no longer be tracked in ``assigned_dp_rank``.
    """
    # Use the shared helper so each rank gets its own mock scheduler.
    # A factory built with MagicMock(return_value=MagicMock()) hands the
    # same instance to both ranks; the second update_from_output
    # assignment would then replace the first, and the two `.called`
    # assertions below would inspect one and the same mock.
    scheduler = self._create_dp_scheduler_with_mocks(
        mock_vllm_config, mock_kv_cache_config,
        mock_structured_output_manager)

    # Setup assigned ranks
    scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}

    # Create DPSchedulerOutput
    dp_output = DPSchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=CachedRequestData(
            req_ids=[],
            resumed_req_ids=[],
            new_token_ids=[],
            all_token_ids=[],
            new_block_ids=[],
            num_computed_tokens=[],
            num_output_tokens=[],
        ),
        num_scheduled_tokens={
            "req1": 10,
            "req2": 20,
            "req3": 15
        },
        total_num_scheduled_tokens=45,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=[],
        finished_req_ids={"req3"},  # req3 finished
        free_encoder_mm_hashes=set(),
        assigned_dp_rank={
            "req1": 0,
            "req2": 1,
            "req3": 0
        },
    )

    # Create mock model runner output
    model_output = ModelRunnerOutput(
        req_ids=["req1", "req2", "req3"],
        req_id_to_index={
            "req1": 0,
            "req2": 1,
            "req3": 2
        },
        sampled_token_ids=torch.tensor([100, 200, 300]),
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=None,
        num_nans_in_logits=0,
        kv_connector_output=None,
    )

    # Mock rank scheduler outputs (cached from schedule call)
    rank_output_0 = MagicMock()
    rank_output_1 = MagicMock()
    scheduler.cached_schedulers_output.append(
        [rank_output_0, rank_output_1])

    # Per-rank engine outputs: rank 0 reports req3 as finished.
    engine_output_0 = EngineCoreOutputs()
    engine_output_0.engine_index = 0
    engine_output_0.outputs = []
    engine_output_0.finished_requests = {"req3"}

    engine_output_1 = EngineCoreOutputs()
    engine_output_1.engine_index = 0
    engine_output_1.outputs = []
    engine_output_1.finished_requests = set()

    scheduler.schedulers[0].update_from_output = MagicMock(
        return_value={0: engine_output_0})
    scheduler.schedulers[1].update_from_output = MagicMock(
        return_value={0: engine_output_1})

    # Mock make_stats
    scheduler.make_stats = MagicMock(return_value=None)

    _ = scheduler.update_from_output(dp_output, model_output)

    # Verify each rank's scheduler was updated
    assert scheduler.schedulers[0].update_from_output.called
    assert scheduler.schedulers[1].update_from_output.called

    # Verify finished request was cleaned up
    assert "req3" not in scheduler.assigned_dp_rank
    assert "req1" in scheduler.assigned_dp_rank
    assert "req2" in scheduler.assigned_dp_rank
|
|
545
|
+
|
|
546
|
+
def test_split_model_output_by_rank(self, mock_vllm_config,
                                    mock_kv_cache_config,
                                    mock_structured_output_manager):
    """Check that a global model output is partitioned by assigned rank.

    Four requests are mapped alternately onto ranks 0 and 1 through
    ``assigned_dp_rank``; the split result is expected to contain two
    entries whose ``req_ids`` keep the original request order.
    """
    rank_scheduler_factory = MagicMock(return_value=MagicMock())
    with patch.object(mock_vllm_config.scheduler_config,
                      '_original_scheduler_cls', rank_scheduler_factory):
        dp_scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        # req1/req3 are owned by rank 0, req2/req4 by rank 1.
        dp_scheduler.assigned_dp_rank = {
            "req1": 0,
            "req2": 1,
            "req3": 0,
            "req4": 1,
        }

        # Global output covering every request, in submission order.
        all_req_ids = ["req1", "req2", "req3", "req4"]
        global_output = ModelRunnerOutput(
            req_ids=list(all_req_ids),
            req_id_to_index={
                req_id: position
                for position, req_id in enumerate(all_req_ids)
            },
            sampled_token_ids=torch.tensor([100, 200, 300, 400]),
            logprobs=None,
            prompt_logprobs_dict={},
            pooler_output=None,
            num_nans_in_logits=0,
            kv_connector_output=None,
        )

        per_rank_outputs = dp_scheduler._split_model_output_by_rank(
            global_output)

        # Each rank must only see the requests assigned to it.
        assert len(per_rank_outputs) == 2
        assert per_rank_outputs[0].req_ids == ["req1", "req3"]
        assert per_rank_outputs[1].req_ids == ["req2", "req4"]
|
|
591
|
+
|
|
592
|
+
def test_cleanup_finished_requests(self, mock_vllm_config,
                                   mock_kv_cache_config,
                                   mock_structured_output_manager):
    """Check that finished requests are dropped from ``assigned_dp_rank``.

    Requests that were not reported as finished must keep their entry.
    """
    rank_scheduler_factory = MagicMock(return_value=MagicMock())
    with patch.object(mock_vllm_config.scheduler_config,
                      '_original_scheduler_cls', rank_scheduler_factory):
        dp_scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        # Three tracked requests spread over two ranks.
        dp_scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}

        # Report req1 and req3 as finished.
        dp_scheduler._cleanup_finished_requests({"req1", "req3"})

        still_assigned = dp_scheduler.assigned_dp_rank
        for finished_id in ("req1", "req3"):
            assert finished_id not in still_assigned
        assert "req2" in still_assigned
|
|
616
|
+
|
|
617
|
+
def test_finish_requests_single_and_multiple(
        self, mock_vllm_config, mock_kv_cache_config,
        mock_structured_output_manager):
    """Check ``finish_requests`` with a single id and with a list of ids.

    In both cases the per-rank scheduler owning a request is expected to
    receive that request id wrapped in a list, plus the status.
    """
    dp_scheduler = self._create_dp_scheduler_with_mocks(
        mock_vllm_config, mock_kv_cache_config,
        mock_structured_output_manager)

    # req1/req3 belong to rank 0, req2 to rank 1.
    dp_scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}

    # Replace the per-rank entry points so calls can be inspected.
    rank0 = dp_scheduler.schedulers[0]
    rank1 = dp_scheduler.schedulers[1]
    rank0.finish_requests = MagicMock()
    rank1.finish_requests = MagicMock()

    # A bare string id is forwarded as a one-element list.
    dp_scheduler.finish_requests("req1", finished_status="completed")
    rank0.finish_requests.assert_called_with(["req1"], "completed")

    # Start the list case from a clean call history.
    for rank_scheduler in (rank0, rank1):
        rank_scheduler.finish_requests.reset_mock()

    dp_scheduler.finish_requests(["req1", "req2"],
                                 finished_status="completed")
    rank0.finish_requests.assert_called_once_with(["req1"], "completed")
    rank1.finish_requests.assert_called_once_with(["req2"], "completed")
|
|
647
|
+
|
|
648
|
+
def test_get_num_unfinished_requests(self, mock_vllm_config,
                                     mock_kv_cache_config,
                                     mock_structured_output_manager):
    """Check that unfinished-request counts are summed over both ranks."""
    dp_scheduler = self._create_dp_scheduler_with_mocks(
        mock_vllm_config, mock_kv_cache_config,
        mock_structured_output_manager)

    # Rank 0 reports 5 unfinished requests, rank 1 reports 3.
    for rank, unfinished in enumerate((5, 3)):
        dp_scheduler.schedulers[rank].get_num_unfinished_requests = (
            MagicMock(return_value=unfinished))

    assert dp_scheduler.get_num_unfinished_requests() == 8
|
|
663
|
+
|
|
664
|
+
def test_has_finished_requests(self, mock_vllm_config,
                               mock_kv_cache_config,
                               mock_structured_output_manager):
    """Check ``has_finished_requests`` against the per-rank answers.

    The result is expected to be True when a rank reports finished
    requests and False when none does.
    """
    rank_scheduler_factory = MagicMock(return_value=MagicMock())
    with patch.object(mock_vllm_config.scheduler_config,
                      '_original_scheduler_cls', rank_scheduler_factory):
        dp_scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        rank0 = dp_scheduler.schedulers[0]
        rank1 = dp_scheduler.schedulers[1]

        # Only rank 1 reports finished requests.
        rank0.has_finished_requests = MagicMock(return_value=False)
        rank1.has_finished_requests = MagicMock(return_value=True)
        assert dp_scheduler.has_finished_requests() is True

        # Now no rank reports finished requests.
        rank1.has_finished_requests = MagicMock(return_value=False)
        assert dp_scheduler.has_finished_requests() is False
|
|
690
|
+
|
|
691
|
+
def test_get_request_counts(self, mock_vllm_config, mock_kv_cache_config,
                            mock_structured_output_manager):
    """Check that running/waiting queue lengths are summed over ranks."""
    dp_scheduler = self._create_dp_scheduler_with_mocks(
        mock_vllm_config, mock_kv_cache_config,
        mock_structured_output_manager)

    rank0 = dp_scheduler.schedulers[0]
    rank1 = dp_scheduler.schedulers[1]

    # Rank 0: 2 running, 1 waiting.
    rank0.running = [MagicMock() for _ in range(2)]
    rank0.waiting = [MagicMock()]
    # Rank 1: 1 running, 3 waiting.
    rank1.running = [MagicMock()]
    rank1.waiting = [MagicMock() for _ in range(3)]

    num_running, num_waiting = dp_scheduler.get_request_counts()

    assert num_running == 3  # 2 + 1
    assert num_waiting == 4  # 1 + 3
|
|
711
|
+
|
|
712
|
+
def test_reset_prefix_cache(self, mock_vllm_config, mock_kv_cache_config,
                            mock_structured_output_manager):
    """Check that ``reset_prefix_cache`` is forwarded to both ranks.

    With both per-rank resets returning True, the combined result is
    expected to be True and each rank must be called exactly once.
    """
    dp_scheduler = self._create_dp_scheduler_with_mocks(
        mock_vllm_config, mock_kv_cache_config,
        mock_structured_output_manager)

    rank_schedulers = (dp_scheduler.schedulers[0],
                       dp_scheduler.schedulers[1])
    for rank_scheduler in rank_schedulers:
        rank_scheduler.reset_prefix_cache = MagicMock(return_value=True)

    outcome = dp_scheduler.reset_prefix_cache()

    assert outcome is True
    for rank_scheduler in rank_schedulers:
        rank_scheduler.reset_prefix_cache.assert_called_once()
|
|
729
|
+
|
|
730
|
+
def test_make_stats_with_logging_enabled(self, mock_vllm_config,
                                         mock_kv_cache_config,
                                         mock_structured_output_manager):
    """Check that ``make_stats`` combines the stats of both ranks.

    Each rank scheduler is mocked to return a fixed ``SchedulerStats``.
    The combined result is expected to sum the request counters and the
    prefix-cache counters, and to report the mean KV-cache usage.
    """
    scheduler = self._create_dp_scheduler_with_mocks(
        mock_vllm_config,
        mock_kv_cache_config,
        mock_structured_output_manager,
        log_stats=True)

    # Stats reported by rank 0.
    stats_0 = SchedulerStats(
        num_running_reqs=3,
        num_waiting_reqs=2,
        kv_cache_usage=0.5,
        prefix_cache_stats=PrefixCacheStats(reset=False,
                                            requests=10,
                                            queries=8,
                                            hits=5),
        connector_prefix_cache_stats=PrefixCacheStats(reset=False,
                                                      requests=5,
                                                      queries=4,
                                                      hits=2),
        spec_decoding_stats=None,
        kv_connector_stats=None,
    )

    # Stats reported by rank 1.
    stats_1 = SchedulerStats(
        num_running_reqs=4,
        num_waiting_reqs=1,
        kv_cache_usage=0.7,
        prefix_cache_stats=PrefixCacheStats(reset=False,
                                            requests=15,
                                            queries=12,
                                            hits=8),
        connector_prefix_cache_stats=PrefixCacheStats(reset=False,
                                                      requests=6,
                                                      queries=5,
                                                      hits=3),
        spec_decoding_stats=None,
        kv_connector_stats=None,
    )

    scheduler.schedulers[0].make_stats = MagicMock(return_value=stats_0)
    scheduler.schedulers[1].make_stats = MagicMock(return_value=stats_1)

    combined_stats = scheduler.make_stats()

    # Request counters are summed across ranks.
    assert combined_stats.num_running_reqs == 7  # 3 + 4
    assert combined_stats.num_waiting_reqs == 3  # 2 + 1
    # KV-cache usage is a computed float mean, (0.5 + 0.7) / 2; compare
    # with a tolerance so the check does not depend on the exact order of
    # floating-point operations used by the implementation.
    assert combined_stats.kv_cache_usage == pytest.approx(0.6)

    # Prefix cache counters are summed across ranks.
    assert combined_stats.prefix_cache_stats.requests == 25  # 10 + 15
    assert combined_stats.prefix_cache_stats.queries == 20  # 8 + 12
    assert combined_stats.prefix_cache_stats.hits == 13  # 5 + 8

    # Connector prefix cache counters are summed across ranks.
    assert combined_stats.connector_prefix_cache_stats.requests == 11  # 5 + 6
    assert combined_stats.connector_prefix_cache_stats.queries == 9  # 4 + 5
    assert combined_stats.connector_prefix_cache_stats.hits == 5  # 2 + 3
|
|
792
|
+
|
|
793
|
+
def test_make_stats_with_logging_disabled(self, mock_vllm_config,
                                          mock_kv_cache_config,
                                          mock_structured_output_manager):
    """Check that ``make_stats`` yields None when ``log_stats`` is False."""
    rank_scheduler_factory = MagicMock(return_value=MagicMock())
    with patch.object(mock_vllm_config.scheduler_config,
                      '_original_scheduler_cls', rank_scheduler_factory):
        dp_scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
            log_stats=False,
        )

        # With stats logging switched off nothing should be produced.
        assert dp_scheduler.make_stats() is None
|
|
810
|
+
|
|
811
|
+
def test_update_draft_token_ids(self, mock_vllm_config,
                                mock_kv_cache_config,
                                mock_structured_output_manager):
    """Check that draft tokens are forwarded to the rank owning each request.

    Only the routed request ids are verified; the token payloads handed
    to each rank are not inspected.
    """
    dp_scheduler = self._create_dp_scheduler_with_mocks(
        mock_vllm_config, mock_kv_cache_config,
        mock_structured_output_manager)

    # req1/req3 belong to rank 0, req2 to rank 1.
    dp_scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}

    # Stand-in for the draft-token container: parallel lists of request
    # ids and their proposed tokens.
    drafts = MagicMock()
    drafts.req_ids = ["req1", "req2", "req3"]
    drafts.draft_token_ids = [
        [101, 102, 103],
        [201, 202],
        [301, 302, 303, 304],
    ]

    rank0 = dp_scheduler.schedulers[0]
    rank1 = dp_scheduler.schedulers[1]
    rank0.update_draft_token_ids = MagicMock()
    rank1.update_draft_token_ids = MagicMock()

    dp_scheduler.update_draft_token_ids(drafts)

    # Both ranks own at least one request, so both must be invoked.
    assert rank0.update_draft_token_ids.called
    assert rank1.update_draft_token_ids.called

    # First positional argument of the most recent call on rank 0.
    forwarded_to_rank0 = rank0.update_draft_token_ids.call_args.args[0]
    assert "req1" in forwarded_to_rank0.req_ids
    assert "req3" in forwarded_to_rank0.req_ids

    # First positional argument of the most recent call on rank 1.
    forwarded_to_rank1 = rank1.update_draft_token_ids.call_args.args[0]
    assert "req2" in forwarded_to_rank1.req_ids
|
|
851
|
+
|
|
852
|
+
def test_shutdown(self, mock_vllm_config, mock_kv_cache_config,
                  mock_structured_output_manager):
    """Check that ``shutdown`` is invoked exactly once on each rank."""
    dp_scheduler = self._create_dp_scheduler_with_mocks(
        mock_vllm_config, mock_kv_cache_config,
        mock_structured_output_manager)

    rank_schedulers = (dp_scheduler.schedulers[0],
                       dp_scheduler.schedulers[1])
    for rank_scheduler in rank_schedulers:
        rank_scheduler.shutdown = MagicMock()

    dp_scheduler.shutdown()

    for rank_scheduler in rank_schedulers:
        rank_scheduler.shutdown.assert_called_once()
|
|
866
|
+
|
|
867
|
+
|
|
868
|
+
class TestUpdateVllmConfigForDPScheduler:
    """Tests for ``update_vllm_config_for_dp_scheduler``."""

    def test_update_config_with_dp_size_greater_than_one(self):
        """With a total DP size of 2 the scheduler class is swapped.

        The original class is expected to be stored in
        ``_original_scheduler_cls`` and ``scheduler_cls`` replaced by
        ``DPScheduler``.
        """
        vllm_config = MagicMock()
        vllm_config.sharding_config.total_dp_size = 2
        sched_cfg = vllm_config.scheduler_config
        sched_cfg._original_scheduler_cls = None
        sched_cfg.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"
        sched_cfg.async_scheduling = False

        update_vllm_config_for_dp_scheduler(vllm_config)

        # The original class is remembered and the DP wrapper installed.
        assert vllm_config.scheduler_config._original_scheduler_cls == Scheduler
        assert vllm_config.scheduler_config.scheduler_cls == DPScheduler

    def test_update_config_with_dp_size_one(self):
        """With a total DP size of 1 ``scheduler_cls`` is left untouched."""
        vllm_config = MagicMock()
        vllm_config.sharding_config.total_dp_size = 1
        configured_cls = "vllm.v1.core.sched.scheduler.Scheduler"
        vllm_config.scheduler_config.scheduler_cls = configured_cls

        update_vllm_config_for_dp_scheduler(vllm_config)

        # Nothing should have been rewritten.
        assert vllm_config.scheduler_config.scheduler_cls == configured_cls
|
|
896
|
+
|
|
897
|
+
|
|
898
|
+
# Allow running this test module directly as a script.
if __name__ == "__main__":
    # pytest.main returns the run's exit code instead of exiting; pass it
    # on so a failing run yields a non-zero process status.
    raise SystemExit(pytest.main([__file__, "-v"]))
|