tpu_inference-0.0.1rc1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (174)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +374 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +648 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +88 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +203 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +235 -0
  27. tpu_inference/__init__.py +53 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +49 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +727 -0
  37. tpu_inference/distributed/utils.py +60 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +160 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +382 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1566 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1501 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1603 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +396 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +469 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +110 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +331 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +368 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +310 -0
  120. tpu_inference/models/__init__.py +0 -0
  121. tpu_inference/models/common/__init__.py +0 -0
  122. tpu_inference/models/common/model_loader.py +478 -0
  123. tpu_inference/models/jax/__init__.py +0 -0
  124. tpu_inference/models/jax/deepseek_v3.py +868 -0
  125. tpu_inference/models/jax/gpt_oss.py +492 -0
  126. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  127. tpu_inference/models/jax/llama3.py +376 -0
  128. tpu_inference/models/jax/llama4.py +629 -0
  129. tpu_inference/models/jax/llama_eagle3.py +336 -0
  130. tpu_inference/models/jax/llama_guard_4.py +361 -0
  131. tpu_inference/models/jax/qwen2.py +376 -0
  132. tpu_inference/models/jax/qwen2_5_vl.py +1218 -0
  133. tpu_inference/models/jax/qwen3.py +303 -0
  134. tpu_inference/models/jax/utils/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/file_utils.py +96 -0
  136. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  137. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  138. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  139. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  140. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  141. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  142. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  143. tpu_inference/models/jax/utils/quantization/quantization_utils.py +650 -0
  144. tpu_inference/models/jax/utils/weight_utils.py +584 -0
  145. tpu_inference/models/vllm/__init__.py +0 -0
  146. tpu_inference/models/vllm/vllm_model_wrapper.py +293 -0
  147. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  148. tpu_inference/platforms/__init__.py +2 -0
  149. tpu_inference/platforms/tpu_platform.py +275 -0
  150. tpu_inference/runner/__init__.py +0 -0
  151. tpu_inference/runner/block_table.py +122 -0
  152. tpu_inference/runner/compilation_manager.py +865 -0
  153. tpu_inference/runner/input_batch.py +435 -0
  154. tpu_inference/runner/kv_cache.py +132 -0
  155. tpu_inference/runner/kv_cache_manager.py +478 -0
  156. tpu_inference/runner/lora_utils.py +92 -0
  157. tpu_inference/runner/multimodal_manager.py +217 -0
  158. tpu_inference/runner/persistent_batch_manager.py +282 -0
  159. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  160. tpu_inference/runner/structured_decoding_manager.py +87 -0
  161. tpu_inference/runner/tpu_runner.py +1744 -0
  162. tpu_inference/runner/utils.py +426 -0
  163. tpu_inference/spec_decode/__init__.py +0 -0
  164. tpu_inference/spec_decode/jax/__init__.py +0 -0
  165. tpu_inference/spec_decode/jax/eagle3.py +417 -0
  166. tpu_inference/tpu_info.py +78 -0
  167. tpu_inference/utils.py +340 -0
  168. tpu_inference/worker/__init__.py +0 -0
  169. tpu_inference/worker/tpu_worker.py +458 -0
  170. tpu_inference-0.0.1rc1.dist-info/METADATA +108 -0
  171. tpu_inference-0.0.1rc1.dist-info/RECORD +174 -0
  172. tpu_inference-0.0.1rc1.dist-info/WHEEL +5 -0
  173. tpu_inference-0.0.1rc1.dist-info/licenses/LICENSE +201 -0
  174. tpu_inference-0.0.1rc1.dist-info/top_level.txt +2 -0
tests/core/test_dp_scheduler.py
@@ -0,0 +1,899 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+from vllm.config import VllmConfig
+from vllm.v1.core.sched.output import (CachedRequestData, GrammarOutput,
+                                        SchedulerOutput)
+from vllm.v1.core.sched.scheduler import Scheduler
+from vllm.v1.engine import EngineCoreOutputs
+from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
+from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.request import Request
+
+from tpu_inference.core.sched.dp_scheduler import (
+    DPScheduler, DPSchedulerOutput, update_vllm_config_for_dp_scheduler)
+
+
+class TestDPScheduler:
+
+    @pytest.fixture
+    def mock_vllm_config(self):
+        """Create a mock VllmConfig for testing."""
+        config = MagicMock(spec=VllmConfig)
+        config.sharding_config = MagicMock()
+        config.sharding_config.total_dp_size = 2
+        config.scheduler_config = MagicMock()
+        config.scheduler_config._original_scheduler_cls = Scheduler
+        config.scheduler_config.max_num_seqs = 8
+        config.scheduler_config.max_num_batched_tokens = 1024
+        config.scheduler_config.async_scheduling = False
+        return config
+
+    @pytest.fixture
+    def mock_kv_cache_config(self):
+        """Create a mock KVCacheConfig for testing."""
+        config = MagicMock(spec=KVCacheConfig)
+        config.num_blocks = 100
+        return config
+
+    @pytest.fixture
+    def mock_structured_output_manager(self):
+        """Create a mock StructuredOutputManager."""
+        return MagicMock()
+
+    def _create_dp_scheduler_with_mocks(self, mock_vllm_config,
+                                        mock_kv_cache_config,
+                                        mock_structured_output_manager,
+                                        **kwargs):
+        """Helper to create a DPScheduler with properly mocked schedulers."""
+        # Create individual mock scheduler instances
+        mock_scheduler_0 = MagicMock()
+        mock_scheduler_1 = MagicMock()
+
+        # Patch the Scheduler class to return our mock instances
+        with patch.object(
+                mock_vllm_config.scheduler_config, '_original_scheduler_cls',
+                MagicMock(side_effect=[mock_scheduler_0, mock_scheduler_1])):
+            scheduler = DPScheduler(
+                vllm_config=mock_vllm_config,
+                kv_cache_config=mock_kv_cache_config,
+                structured_output_manager=mock_structured_output_manager,
+                block_size=16,
+                **kwargs)
+
+        return scheduler
+
+    def test_init_creates_per_rank_schedulers(
+        self,
+        mock_vllm_config,
+        mock_kv_cache_config,
+        mock_structured_output_manager,
+    ):
+        """Test Initialization creates schedulers for each DP rank."""
+        # Mock the scheduler class
+        mock_scheduler_instance = MagicMock()
+        mock_scheduler_cls = MagicMock(return_value=mock_scheduler_instance)
+
+        with patch.object(mock_vllm_config.scheduler_config,
+                          '_original_scheduler_cls', mock_scheduler_cls):
+            scheduler = DPScheduler(
+                vllm_config=mock_vllm_config,
+                kv_cache_config=mock_kv_cache_config,
+                structured_output_manager=mock_structured_output_manager,
+                block_size=16,
+                log_stats=True,
+            )
+
+        # Verify schedulers were created
+        assert len(scheduler.schedulers) == 2
+        assert scheduler.dp_size == 2
+        assert scheduler.log_stats is True
+        assert len(scheduler.per_rank_kv_cache_configs) == 2
+
+        # Verify each rank got the correct config
+        for rank_config in scheduler.per_rank_kv_cache_configs:
+            assert rank_config.num_blocks == 50  # 100 / 2
+
+    def test_get_rank_token_counts(self, mock_vllm_config,
+                                   mock_kv_cache_config,
+                                   mock_structured_output_manager):
+        """Test _get_rank_token_counts calculates tokens per rank."""
+        scheduler = self._create_dp_scheduler_with_mocks(
+            mock_vllm_config, mock_kv_cache_config,
+            mock_structured_output_manager)
+
+        # Mock requests on different ranks
+        req1 = MagicMock()
+        req1.num_tokens = 10
+        req2 = MagicMock()
+        req2.num_tokens = 20
+        req3 = MagicMock()
+        req3.num_tokens = 15
+
+        scheduler.schedulers[0].running = [req1]
+        scheduler.schedulers[0].waiting = [req2]
+        scheduler.schedulers[1].running = [req3]
+        scheduler.schedulers[1].waiting = []
+
+        rank_tokens = scheduler._get_rank_token_counts()
+
+        assert rank_tokens[0] == 30  # 10 + 20
+        assert rank_tokens[1] == 15
+
+    def test_find_best_rank_with_cache_hit(self, mock_vllm_config,
+                                           mock_kv_cache_config,
+                                           mock_structured_output_manager):
+        """Test _find_best_rank_for_request with cache hit."""
+        scheduler = self._create_dp_scheduler_with_mocks(
+            mock_vllm_config, mock_kv_cache_config,
+            mock_structured_output_manager)
+
+        # Mock request
+        mock_request = MagicMock(spec=Request)
+
+        # Mock KV cache managers with different cache hits
+        scheduler.schedulers[0].kv_cache_manager = MagicMock()
+        scheduler.schedulers[
+            0].kv_cache_manager.get_computed_blocks.return_value = (
+                [],
+                10,
+            )  # 10 cached tokens
+
+        scheduler.schedulers[1].kv_cache_manager = MagicMock()
+        scheduler.schedulers[
+            1].kv_cache_manager.get_computed_blocks.return_value = (
+                [],
+                20,
+            )  # 20 cached tokens (better)
+
+        # Mock empty running/waiting queues
+        scheduler.schedulers[0].running = []
+        scheduler.schedulers[0].waiting = []
+        scheduler.schedulers[1].running = []
+        scheduler.schedulers[1].waiting = []
+
+        rank = scheduler._find_best_rank_for_request(mock_request)
+
+        # Should choose rank 1 with better cache hit
+        assert rank == 1
+
+    def test_find_best_rank_without_cache_hit(self, mock_vllm_config,
+                                              mock_kv_cache_config,
+                                              mock_structured_output_manager):
+        """Test _find_best_rank_for_request without cache hit (load balancing)."""
+        scheduler = self._create_dp_scheduler_with_mocks(
+            mock_vllm_config, mock_kv_cache_config,
+            mock_structured_output_manager)
+
+        # Mock request
+        mock_request = MagicMock(spec=Request)
+
+        # Mock KV cache managers with no cache hits
+        scheduler.schedulers[0].kv_cache_manager = MagicMock()
+        scheduler.schedulers[
+            0].kv_cache_manager.get_computed_blocks.return_value = ([], 0)
+
+        scheduler.schedulers[1].kv_cache_manager = MagicMock()
+        scheduler.schedulers[
+            1].kv_cache_manager.get_computed_blocks.return_value = ([], 0)
+
+        # Mock requests with different token counts
+        req1 = MagicMock()
+        req1.num_tokens = 50
+        req2 = MagicMock()
+        req2.num_tokens = 30
+
+        scheduler.schedulers[0].running = [req1]
+        scheduler.schedulers[0].waiting = []
+        scheduler.schedulers[1].running = [req2]
+        scheduler.schedulers[1].waiting = []
+
+        rank = scheduler._find_best_rank_for_request(mock_request)
+
+        # Should choose rank 1 with fewer tokens
+        assert rank == 1
+
+    def test_add_request_assigns_to_best_rank(self, mock_vllm_config,
+                                              mock_kv_cache_config,
+                                              mock_structured_output_manager):
+        """Test add_request assigns and adds request to best rank."""
+        scheduler = self._create_dp_scheduler_with_mocks(
+            mock_vllm_config, mock_kv_cache_config,
+            mock_structured_output_manager)
+
+        # Mock the rank selection
+        mock_request = MagicMock(spec=Request)
+        mock_request.request_id = "req1"
+
+        # Mock _find_best_rank_for_request to return rank 1
+        scheduler._find_best_rank_for_request = MagicMock(return_value=1)
+
+        # Mock schedulers
+        scheduler.schedulers[0].add_request = MagicMock()
+        scheduler.schedulers[1].add_request = MagicMock()
+
+        scheduler.add_request(mock_request)
+
+        # Verify request was assigned to rank 1
+        assert scheduler.assigned_dp_rank["req1"] == 1
+        scheduler.schedulers[1].add_request.assert_called_once_with(
+            mock_request)
+        scheduler.schedulers[0].add_request.assert_not_called()
+
+    def test_schedule_runs_all_schedulers(self, mock_vllm_config,
+                                          mock_kv_cache_config,
+                                          mock_structured_output_manager):
+        """Test schedule runs all schedulers and combines output."""
+        scheduler = self._create_dp_scheduler_with_mocks(
+            mock_vllm_config, mock_kv_cache_config,
+            mock_structured_output_manager)
+
+        # Mock scheduler outputs
+        mock_output_0 = MagicMock(spec=SchedulerOutput)
+        mock_output_0.scheduled_new_reqs = []
+        mock_output_0.num_scheduled_tokens = {"req1": 10}
+        mock_output_0.total_num_scheduled_tokens = 10
+        mock_output_0.finished_req_ids = set()
+        mock_output_0.scheduled_cached_reqs = CachedRequestData(
+            req_ids=[],
+            resumed_req_ids=[],
+            new_token_ids=[],
+            all_token_ids=[],
+            new_block_ids=[],
+            num_computed_tokens=[],
+            num_output_tokens=[],
+        )
+        mock_output_0.scheduled_spec_decode_tokens = {}
+        mock_output_0.scheduled_encoder_inputs = {}
+        mock_output_0.num_common_prefix_blocks = []
+
+        mock_output_1 = MagicMock(spec=SchedulerOutput)
+        mock_output_1.scheduled_new_reqs = []
+        mock_output_1.num_scheduled_tokens = {"req2": 20}
+        mock_output_1.total_num_scheduled_tokens = 20
+        mock_output_1.finished_req_ids = set()
+        mock_output_1.scheduled_cached_reqs = CachedRequestData(
+            req_ids=[],
+            resumed_req_ids=[],
+            new_token_ids=[],
+            all_token_ids=[],
+            new_block_ids=[],
+            num_computed_tokens=[],
+            num_output_tokens=[],
+        )
+        mock_output_1.scheduled_spec_decode_tokens = {}
+        mock_output_1.scheduled_encoder_inputs = {}
+        mock_output_1.num_common_prefix_blocks = []
+
+        scheduler.schedulers[0].schedule = MagicMock(
+            return_value=mock_output_0)
+        scheduler.schedulers[1].schedule = MagicMock(
+            return_value=mock_output_1)
+        scheduler.schedulers[0].running = []
+        scheduler.schedulers[0].waiting = []
+        scheduler.schedulers[1].running = []
+        scheduler.schedulers[1].waiting = []
+
+        # Assign ranks for requests
+        scheduler.assigned_dp_rank = {"req1": 0, "req2": 1}
+
+        output = scheduler.schedule()
+
+        # Verify combined output
+        assert isinstance(output, DPSchedulerOutput)
+        assert output.total_num_scheduled_tokens == 30  # 10 + 20
+        assert "req1" in output.num_scheduled_tokens
+        assert "req2" in output.num_scheduled_tokens
+        assert output.assigned_dp_rank == {"req1": 0, "req2": 1}
+
+    def test_combine_cached_request_data(self, mock_vllm_config,
+                                         mock_kv_cache_config,
+                                         mock_structured_output_manager):
+        """Test _combine_cached_request_data combines data from all ranks."""
+        mock_scheduler_cls = MagicMock(return_value=MagicMock())
+        with patch.object(mock_vllm_config.scheduler_config,
+                          '_original_scheduler_cls', mock_scheduler_cls):
+            scheduler = DPScheduler(
+                vllm_config=mock_vllm_config,
+                kv_cache_config=mock_kv_cache_config,
+                structured_output_manager=mock_structured_output_manager,
+                block_size=16,
+            )
+
+        # Create mock rank outputs with different cached request data
+        output_0 = MagicMock(spec=SchedulerOutput)
+        output_0.scheduled_cached_reqs = CachedRequestData(
+            req_ids=["req1"],
+            resumed_req_ids=["req1"],
+            new_token_ids=[[1, 2, 3]],
+            all_token_ids=[[1, 2, 3, 4, 5]],
+            new_block_ids=[[10, 11]],
+            num_computed_tokens=[5],
+            num_output_tokens=[3],
+        )
+
+        output_1 = MagicMock(spec=SchedulerOutput)
+        output_1.scheduled_cached_reqs = CachedRequestData(
+            req_ids=["req2"],
+            resumed_req_ids=[],
+            new_token_ids=[[6, 7]],
+            all_token_ids=[[6, 7, 8, 9]],
+            new_block_ids=[[20, 21]],
+            num_computed_tokens=[4],
+            num_output_tokens=[2],
+        )
+
+        rank_outputs = [output_0, output_1]
+        combined = scheduler._combine_cached_request_data(rank_outputs)
+
+        # Verify combined data
+        assert combined.req_ids == ["req1", "req2"]
+        assert combined.resumed_req_ids == ["req1"]
+        assert combined.new_token_ids == [[1, 2, 3], [6, 7]]
+        assert combined.all_token_ids == [[1, 2, 3, 4, 5], [6, 7, 8, 9]]
+        assert combined.new_block_ids == [[10, 11], [20, 21]]
+        assert combined.num_computed_tokens == [5, 4]
+        assert combined.num_output_tokens == [3, 2]
+
+    def test_get_grammar_bitmask_with_structured_output(
+            self, mock_vllm_config, mock_kv_cache_config,
+            mock_structured_output_manager):
+        """Test get_grammar_bitmask combines bitmasks from all ranks."""
+        scheduler = self._create_dp_scheduler_with_mocks(
+            mock_vllm_config, mock_kv_cache_config,
+            mock_structured_output_manager)
+
+        # Create mock scheduler outputs
+        mock_output_0 = MagicMock()
+        mock_output_1 = MagicMock()
+
+        # Mock grammar outputs from each rank
+        grammar_output_0 = GrammarOutput(
+            structured_output_request_ids=["req1"],
+            grammar_bitmask=torch.ones((1, 100), dtype=torch.bool),
+        )
+        grammar_output_1 = GrammarOutput(
+            structured_output_request_ids=["req2"],
+            grammar_bitmask=torch.ones((1, 100), dtype=torch.bool) * 0,
+        )
+
+        scheduler.schedulers[0].get_grammar_bitmask = MagicMock(
+            return_value=grammar_output_0)
+        scheduler.schedulers[1].get_grammar_bitmask = MagicMock(
+            return_value=grammar_output_1)
+
+        # Cache scheduler outputs
+        scheduler.cached_schedulers_output.append(
+            [mock_output_0, mock_output_1])
+
+        # Create a DPSchedulerOutput
+        dp_output = DPSchedulerOutput(
+            scheduled_new_reqs=[],
+            scheduled_cached_reqs=CachedRequestData(
+                req_ids=[],
+                resumed_req_ids=[],
+                new_token_ids=[],
+                all_token_ids=[],
+                new_block_ids=[],
+                num_computed_tokens=[],
+                num_output_tokens=[],
+            ),
+            num_scheduled_tokens={},
+            total_num_scheduled_tokens=0,
+            scheduled_spec_decode_tokens={},
+            scheduled_encoder_inputs={},
+            num_common_prefix_blocks=[],
+            finished_req_ids=set(),
+            free_encoder_mm_hashes=set(),
+        )
+
+        result = scheduler.get_grammar_bitmask(dp_output)
+
+        assert result is not None
+        assert result.structured_output_request_ids == ["req1", "req2"]
+        assert result.grammar_bitmask.shape == (2, 100)
+
+    def test_get_grammar_bitmask_no_structured_output(
+            self, mock_vllm_config, mock_kv_cache_config,
+            mock_structured_output_manager):
+        """Test get_grammar_bitmask returns None when no structured output."""
+        mock_scheduler_cls = MagicMock(return_value=MagicMock())
+        with patch.object(mock_vllm_config.scheduler_config,
+                          '_original_scheduler_cls', mock_scheduler_cls):
+            scheduler = DPScheduler(
+                vllm_config=mock_vllm_config,
+                kv_cache_config=mock_kv_cache_config,
+                structured_output_manager=mock_structured_output_manager,
+                block_size=16,
+            )
+
+        # Mock schedulers returning None
+        scheduler.schedulers[0].get_grammar_bitmask = MagicMock(
+            return_value=None)
+        scheduler.schedulers[1].get_grammar_bitmask = MagicMock(
+            return_value=None)
+
+        # Cache scheduler outputs
+        mock_output_0 = MagicMock()
+        mock_output_1 = MagicMock()
+        scheduler.cached_schedulers_output.append(
+            [mock_output_0, mock_output_1])
+
+        dp_output = DPSchedulerOutput(
+            scheduled_new_reqs=[],
+            scheduled_cached_reqs=CachedRequestData(
+                req_ids=[],
+                resumed_req_ids=[],
+                new_token_ids=[],
+                all_token_ids=[],
+                new_block_ids=[],
+                num_computed_tokens=[],
+                num_output_tokens=[],
+            ),
+            num_scheduled_tokens={},
+            total_num_scheduled_tokens=0,
+            scheduled_spec_decode_tokens={},
+            scheduled_encoder_inputs={},
+            num_common_prefix_blocks=[],
+            finished_req_ids=set(),
+            free_encoder_mm_hashes=set(),
+        )
+
+        result = scheduler.get_grammar_bitmask(dp_output)
+        assert result is None
+
+    def test_update_from_output_routes_to_schedulers(
+            self, mock_vllm_config, mock_kv_cache_config,
+            mock_structured_output_manager):
+        """Test update_from_output splits output and updates each scheduler."""
+        mock_scheduler_cls = MagicMock(return_value=MagicMock())
+        with patch.object(mock_vllm_config.scheduler_config,
+                          '_original_scheduler_cls', mock_scheduler_cls):
+            scheduler = DPScheduler(
+                vllm_config=mock_vllm_config,
+                kv_cache_config=mock_kv_cache_config,
+                structured_output_manager=mock_structured_output_manager,
+                block_size=16,
+            )
+
+        # Setup assigned ranks
+        scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}
+
+        # Create DPSchedulerOutput
+        dp_output = DPSchedulerOutput(
+            scheduled_new_reqs=[],
+            scheduled_cached_reqs=CachedRequestData(
+                req_ids=[],
+                resumed_req_ids=[],
+                new_token_ids=[],
+                all_token_ids=[],
+                new_block_ids=[],
+                num_computed_tokens=[],
+                num_output_tokens=[],
+            ),
+            num_scheduled_tokens={
+                "req1": 10,
+                "req2": 20,
+                "req3": 15
+            },
+            total_num_scheduled_tokens=45,
+            scheduled_spec_decode_tokens={},
+            scheduled_encoder_inputs={},
+            num_common_prefix_blocks=[],
+            finished_req_ids={"req3"},  # req3 finished
+            free_encoder_mm_hashes=set(),
+            assigned_dp_rank={
+                "req1": 0,
+                "req2": 1,
+                "req3": 0
+            },
+        )
+
+        # Create mock model runner output
+        model_output = ModelRunnerOutput(
+            req_ids=["req1", "req2", "req3"],
+            req_id_to_index={
+                "req1": 0,
+                "req2": 1,
+                "req3": 2
+            },
+            sampled_token_ids=torch.tensor([100, 200, 300]),
+            logprobs=None,
+            prompt_logprobs_dict={},
+            pooler_output=None,
+            num_nans_in_logits=0,
+            kv_connector_output=None,
+        )
+
+        # Mock rank scheduler outputs (cached from schedule call)
+        rank_output_0 = MagicMock()
+        rank_output_1 = MagicMock()
+        scheduler.cached_schedulers_output.append(
+            [rank_output_0, rank_output_1])
+
+        # Mock scheduler update_from_output
+        engine_output_0 = EngineCoreOutputs()
+        engine_output_0.engine_index = 0
+        engine_output_0.outputs = []
+        engine_output_0.finished_requests = {"req3"}
+
+        engine_output_1 = EngineCoreOutputs()
+        engine_output_1.engine_index = 0
+        engine_output_1.outputs = []
+        engine_output_1.finished_requests = set()
+
+        scheduler.schedulers[0].update_from_output = MagicMock(
+            return_value={0: engine_output_0})
+        scheduler.schedulers[1].update_from_output = MagicMock(
+            return_value={0: engine_output_1})
+
+        # Mock make_stats
+        scheduler.make_stats = MagicMock(return_value=None)
+
+        _ = scheduler.update_from_output(dp_output, model_output)
+
+        # Verify schedulers were updated
+        assert scheduler.schedulers[0].update_from_output.called
+        assert scheduler.schedulers[1].update_from_output.called
+
+        # Verify finished request was cleaned up
+        assert "req3" not in scheduler.assigned_dp_rank
+        assert "req1" in scheduler.assigned_dp_rank
+        assert "req2" in scheduler.assigned_dp_rank
+
+    def test_split_model_output_by_rank(self, mock_vllm_config,
+                                        mock_kv_cache_config,
+                                        mock_structured_output_manager):
+        """Test _split_model_output_by_rank distributes output correctly."""
+        mock_scheduler_cls = MagicMock(return_value=MagicMock())
+        with patch.object(mock_vllm_config.scheduler_config,
+                          '_original_scheduler_cls', mock_scheduler_cls):
+            scheduler = DPScheduler(
+                vllm_config=mock_vllm_config,
+                kv_cache_config=mock_kv_cache_config,
+                structured_output_manager=mock_structured_output_manager,
+                block_size=16,
+            )
+
+        # Setup assigned ranks
+        scheduler.assigned_dp_rank = {
+            "req1": 0,
+            "req2": 1,
+            "req3": 0,
+            "req4": 1
+        }
+
+        # Create global model output
+        global_output = ModelRunnerOutput(
+            req_ids=["req1", "req2", "req3", "req4"],
+            req_id_to_index={
+                "req1": 0,
+                "req2": 1,
+                "req3": 2,
+                "req4": 3
+            },
+            sampled_token_ids=torch.tensor([100, 200, 300, 400]),
+            logprobs=None,
+            prompt_logprobs_dict={},
+            pooler_output=None,
+            num_nans_in_logits=0,
+            kv_connector_output=None,
+        )
+
+        rank_outputs = scheduler._split_model_output_by_rank(global_output)
+
+        # Verify split outputs
+        assert len(rank_outputs) == 2
+        assert rank_outputs[0].req_ids == ["req1", "req3"]
+        assert rank_outputs[1].req_ids == ["req2", "req4"]
+
+    def test_cleanup_finished_requests(self, mock_vllm_config,
+                                       mock_kv_cache_config,
+                                       mock_structured_output_manager):
+        """Test _cleanup_finished_requests removes finished requests."""
+        mock_scheduler_cls = MagicMock(return_value=MagicMock())
+        with patch.object(mock_vllm_config.scheduler_config,
+                          '_original_scheduler_cls', mock_scheduler_cls):
+            scheduler = DPScheduler(
+                vllm_config=mock_vllm_config,
+                kv_cache_config=mock_kv_cache_config,
+                structured_output_manager=mock_structured_output_manager,
+                block_size=16,
+            )
+
+        # Setup assigned ranks
+        scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}
+
+        # Clean up finished requests
+        scheduler._cleanup_finished_requests({"req1", "req3"})
+
+        # Verify cleanup
+        assert "req1" not in scheduler.assigned_dp_rank
+        assert "req3" not in scheduler.assigned_dp_rank
+        assert "req2" in scheduler.assigned_dp_rank
+
+    def test_finish_requests_single_and_multiple(
+            self, mock_vllm_config, mock_kv_cache_config,
+            mock_structured_output_manager):
+        """Test finish_requests handles single string and list."""
+        scheduler = self._create_dp_scheduler_with_mocks(
+            mock_vllm_config, mock_kv_cache_config,
+            mock_structured_output_manager)
+
+        # Setup assigned ranks
+        scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}
+
+        # Mock scheduler finish_requests
+        scheduler.schedulers[0].finish_requests = MagicMock()
+        scheduler.schedulers[1].finish_requests = MagicMock()
+
+        # Test with single string
+        scheduler.finish_requests("req1", finished_status="completed")
+        scheduler.schedulers[0].finish_requests.assert_called_with(["req1"],
+                                                                   "completed")
+
+        # Test with list
+        scheduler.schedulers[0].finish_requests.reset_mock()
+        scheduler.schedulers[1].finish_requests.reset_mock()
+
+        scheduler.finish_requests(["req1", "req2"],
+                                  finished_status="completed")
+        scheduler.schedulers[0].finish_requests.assert_called_once_with(
+            ["req1"], "completed")
+        scheduler.schedulers[1].finish_requests.assert_called_once_with(
+            ["req2"], "completed")
+
+    def test_get_num_unfinished_requests(self, mock_vllm_config,
+                                         mock_kv_cache_config,
+                                         mock_structured_output_manager):
+        """Test get_num_unfinished_requests aggregates across ranks."""
+        scheduler = self._create_dp_scheduler_with_mocks(
+            mock_vllm_config, mock_kv_cache_config,
+            mock_structured_output_manager)
+
+        scheduler.schedulers[0].get_num_unfinished_requests = MagicMock(
+            return_value=5)
+        scheduler.schedulers[1].get_num_unfinished_requests = MagicMock(
+            return_value=3)
+
+        total = scheduler.get_num_unfinished_requests()
+        assert total == 8
+
+    def test_has_finished_requests(self, mock_vllm_config,
+                                   mock_kv_cache_config,
+                                   mock_structured_output_manager):
+        """Test has_finished_requests checks all ranks."""
+        mock_scheduler_cls = MagicMock(return_value=MagicMock())
+        with patch.object(mock_vllm_config.scheduler_config,
+                          '_original_scheduler_cls', mock_scheduler_cls):
+            scheduler = DPScheduler(
+                vllm_config=mock_vllm_config,
+                kv_cache_config=mock_kv_cache_config,
+                structured_output_manager=mock_structured_output_manager,
+                block_size=16,
+            )
+
+        # Test when one rank has finished requests
+        scheduler.schedulers[0].has_finished_requests = MagicMock(
+            return_value=False)
+        scheduler.schedulers[1].has_finished_requests = MagicMock(
+            return_value=True)
+
+        assert scheduler.has_finished_requests() is True
+
+        # Test when no rank has finished requests
+        scheduler.schedulers[1].has_finished_requests = MagicMock(
+            return_value=False)
+        assert scheduler.has_finished_requests() is False
+
+    def test_get_request_counts(self, mock_vllm_config, mock_kv_cache_config,
+                                mock_structured_output_manager):
+        """Test get_request_counts aggregates across ranks."""
+        scheduler = self._create_dp_scheduler_with_mocks(
+            mock_vllm_config, mock_kv_cache_config,
+            mock_structured_output_manager)
+
+        # Mock running and waiting queues
+        scheduler.schedulers[0].running = [MagicMock(),
+                                           MagicMock()]  # 2 running
+        scheduler.schedulers[0].waiting = [MagicMock()]  # 1 waiting
+        scheduler.schedulers[1].running = [MagicMock()]  # 1 running
+        scheduler.schedulers[1].waiting = [
+            MagicMock(), MagicMock(), MagicMock()
+        ]  # 3 waiting
+
+        running, waiting = scheduler.get_request_counts()
+
+        assert running == 3  # 2 + 1
+        assert waiting == 4  # 1 + 3
+
+    def test_reset_prefix_cache(self, mock_vllm_config, mock_kv_cache_config,
+                                mock_structured_output_manager):
+        """Test reset_prefix_cache resets all ranks."""
+        scheduler = self._create_dp_scheduler_with_mocks(
+            mock_vllm_config, mock_kv_cache_config,
+            mock_structured_output_manager)
+
+        scheduler.schedulers[0].reset_prefix_cache = MagicMock(
+            return_value=True)
+        scheduler.schedulers[1].reset_prefix_cache = MagicMock(
+            return_value=True)
+
+        result = scheduler.reset_prefix_cache()
+
+        assert result is True
+        scheduler.schedulers[0].reset_prefix_cache.assert_called_once()
+        scheduler.schedulers[1].reset_prefix_cache.assert_called_once()
+
+    def test_make_stats_with_logging_enabled(self, mock_vllm_config,
+                                             mock_kv_cache_config,
+                                             mock_structured_output_manager):
+        """Test make_stats aggregates stats from all ranks."""
+        scheduler = self._create_dp_scheduler_with_mocks(
+            mock_vllm_config,
+            mock_kv_cache_config,
+            mock_structured_output_manager,
+            log_stats=True)
+
+        # Create mock stats for each rank
+        stats_0 = SchedulerStats(
+            num_running_reqs=3,
+            num_waiting_reqs=2,
+            kv_cache_usage=0.5,
+            prefix_cache_stats=PrefixCacheStats(reset=False,
+                                                requests=10,
+                                                queries=8,
+                                                hits=5),
+            connector_prefix_cache_stats=PrefixCacheStats(reset=False,
+                                                          requests=5,
+                                                          queries=4,
+                                                          hits=2),
+            spec_decoding_stats=None,
+            kv_connector_stats=None,
+        )
+
+        stats_1 = SchedulerStats(
+            num_running_reqs=4,
+            num_waiting_reqs=1,
+            kv_cache_usage=0.7,
+            prefix_cache_stats=PrefixCacheStats(reset=False,
+                                                requests=15,
+                                                queries=12,
+                                                hits=8),
+            connector_prefix_cache_stats=PrefixCacheStats(reset=False,
+                                                          requests=6,
+                                                          queries=5,
+                                                          hits=3),
+            spec_decoding_stats=None,
+            kv_connector_stats=None,
+        )
+
+        scheduler.schedulers[0].make_stats = MagicMock(return_value=stats_0)
+        scheduler.schedulers[1].make_stats = MagicMock(return_value=stats_1)
+
+        combined_stats = scheduler.make_stats()
+
+        # Verify aggregated stats
+        assert combined_stats.num_running_reqs == 7  # 3 + 4
+        assert combined_stats.num_waiting_reqs == 3  # 2 + 1
+        assert combined_stats.kv_cache_usage == 0.6  # (0.5 + 0.7) / 2
+
+        # Verify prefix cache stats
+        assert combined_stats.prefix_cache_stats.requests == 25  # 10 + 15
+        assert combined_stats.prefix_cache_stats.queries == 20  # 8 + 12
+        assert combined_stats.prefix_cache_stats.hits == 13  # 5 + 8
+
+        # Verify connector prefix cache stats
+        assert combined_stats.connector_prefix_cache_stats.requests == 11  # 5 + 6
+        assert combined_stats.connector_prefix_cache_stats.queries == 9  # 4 + 5
+        assert combined_stats.connector_prefix_cache_stats.hits == 5  # 2 + 3
+
+    def test_make_stats_with_logging_disabled(self, mock_vllm_config,
+                                              mock_kv_cache_config,
+                                              mock_structured_output_manager):
+        """Test make_stats returns None when logging is disabled."""
+        mock_scheduler_cls = MagicMock(return_value=MagicMock())
+        with patch.object(mock_vllm_config.scheduler_config,
+                          '_original_scheduler_cls', mock_scheduler_cls):
+            scheduler = DPScheduler(
+                vllm_config=mock_vllm_config,
+                kv_cache_config=mock_kv_cache_config,
+                structured_output_manager=mock_structured_output_manager,
+                block_size=16,
+                log_stats=False,
+            )
+
+        stats = scheduler.make_stats()
+        assert stats is None
+
+    def test_update_draft_token_ids(self, mock_vllm_config,
+                                    mock_kv_cache_config,
+                                    mock_structured_output_manager):
+        """Test update_draft_token_ids routes tokens to correct ranks."""
+        scheduler = self._create_dp_scheduler_with_mocks(
+            mock_vllm_config, mock_kv_cache_config,
+            mock_structured_output_manager)
+
+        # Setup assigned ranks
+        scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}
+
+        # Create mock draft token IDs
+        draft_token_ids = MagicMock()
+        draft_token_ids.req_ids = ["req1", "req2", "req3"]
+        draft_token_ids.draft_token_ids = [
+            [101, 102, 103],
+            [201, 202],
+            [301, 302, 303, 304],
+        ]
+
+        # Mock scheduler update_draft_token_ids
+        scheduler.schedulers[0].update_draft_token_ids = MagicMock()
+        scheduler.schedulers[1].update_draft_token_ids = MagicMock()
+
+        scheduler.update_draft_token_ids(draft_token_ids)
+
+        # Verify each scheduler received correct tokens
+        assert scheduler.schedulers[0].update_draft_token_ids.called
+        assert scheduler.schedulers[1].update_draft_token_ids.called
+
+        # Check rank 0 got req1 and req3
+        call_args_0 = scheduler.schedulers[0].update_draft_token_ids.call_args[
+            0][0]
+        assert "req1" in call_args_0.req_ids
+        assert "req3" in call_args_0.req_ids
+
+        # Check rank 1 got req2
+        call_args_1 = scheduler.schedulers[1].update_draft_token_ids.call_args[
+            0][0]
+        assert "req2" in call_args_1.req_ids
+
+    def test_shutdown(self, mock_vllm_config, mock_kv_cache_config,
+                      mock_structured_output_manager):
+        """Test shutdown calls shutdown on all schedulers."""
+        scheduler = self._create_dp_scheduler_with_mocks(
+            mock_vllm_config, mock_kv_cache_config,
+            mock_structured_output_manager)
+
+        scheduler.schedulers[0].shutdown = MagicMock()
+        scheduler.schedulers[1].shutdown = MagicMock()
+
+        scheduler.shutdown()
+
+        scheduler.schedulers[0].shutdown.assert_called_once()
+        scheduler.schedulers[1].shutdown.assert_called_once()
+
+
+class TestUpdateVllmConfigForDPScheduler:
+    """Test the update_vllm_config_for_dp_scheduler function."""
+
+    def test_update_config_with_dp_size_greater_than_one(self):
+        """Test Config is updated when DP size > 1."""
+        mock_config = MagicMock()
+        mock_config.sharding_config.total_dp_size = 2
+        mock_config.scheduler_config._original_scheduler_cls = None
+        mock_config.scheduler_config.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"
+        mock_config.scheduler_config.async_scheduling = False
+
+        update_vllm_config_for_dp_scheduler(mock_config)
+
+        # Verify config was updated
+        assert mock_config.scheduler_config._original_scheduler_cls == Scheduler
+        assert mock_config.scheduler_config.scheduler_cls == DPScheduler
+
+    def test_update_config_with_dp_size_one(self):
+        """Test that config is NOT updated when DP size == 1."""
+        mock_config = MagicMock()
+        mock_config.sharding_config.total_dp_size = 1
+        original_scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"
+        mock_config.scheduler_config.scheduler_cls = original_scheduler_cls
+
+        update_vllm_config_for_dp_scheduler(mock_config)
+
+        # Verify config was NOT changed
+        assert mock_config.scheduler_config.scheduler_cls == original_scheduler_cls
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])