tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260) hide show
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,724 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from unittest.mock import MagicMock, patch
16
+
17
+ import pytest
18
+ from vllm.config import VllmConfig
19
+ from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
20
+ from vllm.v1.core.sched.scheduler import Scheduler
21
+ from vllm.v1.kv_cache_interface import KVCacheConfig
22
+ from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
23
+ from vllm.v1.request import Request
24
+
25
+ from tpu_inference.core.sched.dp_scheduler import (
26
+ DPScheduler, DPSchedulerOutput, SchedulerCommand,
27
+ update_vllm_config_for_dp_scheduler)
28
+
29
+
30
+ class TestDPScheduler:
31
+
32
@pytest.fixture
def mock_vllm_config(self):
    """Return a ``VllmConfig``-spec'd mock with sharding and scheduler settings."""
    vllm_config = MagicMock(spec=VllmConfig)

    # Sharding section: total data-parallel size is fixed at 2.
    sharding = MagicMock()
    sharding.total_dp_size = 2
    vllm_config.sharding_config = sharding

    # Scheduler section: the stock vLLM Scheduler is recorded as the
    # original scheduler class, alongside the batch limits.
    sched = MagicMock()
    sched._original_scheduler_cls = Scheduler
    sched.max_num_seqs = 8
    sched.max_num_batched_tokens = 1024
    sched.async_scheduling = False
    vllm_config.scheduler_config = sched

    return vllm_config
44
+
45
@pytest.fixture
def mock_kv_cache_config(self):
    """Return a ``KVCacheConfig``-spec'd mock that reports 100 blocks."""
    kv_config = MagicMock(spec=KVCacheConfig)
    kv_config.num_blocks = 100
    return kv_config
51
+
52
@pytest.fixture
def mock_structured_output_manager(self):
    """Return a plain ``MagicMock`` standing in for the structured output manager."""
    manager = MagicMock()
    return manager
56
+
57
def test_init_creates_worker_processes(
    self,
    mock_vllm_config,
    mock_kv_cache_config,
    mock_structured_output_manager,
):
    """Constructing a DPScheduler creates and starts one worker per DP rank."""
    worker_patch = patch(
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    ctx_patch = patch('multiprocessing.get_context')
    with worker_patch, ctx_patch as mock_get_context:
        # Fake multiprocessing context: every Queue()/Process() call hands
        # back the same shared mock, so start() calls accumulate on one
        # object.
        fake_queue = MagicMock()
        fake_process = MagicMock()
        fake_ctx = MagicMock()
        fake_ctx.Queue = MagicMock(return_value=fake_queue)
        fake_ctx.Process = MagicMock(return_value=fake_process)
        mock_get_context.return_value = fake_ctx

        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
            log_stats=True,
        )

        # One entry per rank in every per-rank collection.
        assert scheduler.dp_size == 2
        per_rank_collections = (
            scheduler.processes,
            scheduler.input_queues,
            scheduler.output_queues,
        )
        for collection in per_rank_collections:
            assert len(collection) == 2
        assert scheduler.log_stats is True
        assert len(scheduler.per_rank_kv_cache_configs) == 2

        # The 100 configured blocks are expected to be split evenly.
        for per_rank_config in scheduler.per_rank_kv_cache_configs:
            assert per_rank_config.num_blocks == 50  # 100 / 2

        # Both worker processes must have been started.
        assert fake_process.start.call_count == 2
99
+
100
def test_get_rank_token_counts(self, mock_vllm_config,
                               mock_kv_cache_config,
                               mock_structured_output_manager):
    """_get_rank_token_counts asks every worker and returns per-rank counts."""
    worker_patch = patch(
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with worker_patch, patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        # Replace the queues with mocks; each worker replies with its count.
        worker_replies = [30, 15]
        scheduler.input_queues = [MagicMock() for _ in worker_replies]
        scheduler.output_queues = [MagicMock() for _ in worker_replies]
        for out_queue, reply in zip(scheduler.output_queues, worker_replies):
            out_queue.get = MagicMock(return_value=reply)

        rank_tokens = scheduler._get_rank_token_counts()

        # Every rank must have received the GET_TOKEN_COUNT command.
        for in_queue in scheduler.input_queues:
            in_queue.put.assert_called_with(
                (SchedulerCommand.GET_TOKEN_COUNT, None))

        for rank, reply in enumerate(worker_replies):
            assert rank_tokens[rank] == reply
133
+
134
def test_find_best_rank_with_cache_hit(self, mock_vllm_config,
                                       mock_kv_cache_config,
                                       mock_structured_output_manager):
    """Test _find_best_rank_for_request prefers cache hits.

    Each mocked output queue yields two replies in order: a scalar, then a
    ``([], N)`` tuple. Presumably these are the rank's token count followed
    by its (blocks, cached-token count) lookup result -- confirm against
    DPScheduler. Rank 1 carries the larger ``N`` (25 vs. 10) and is the
    expected choice.
    """
    with patch(
            'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
    ):
        with patch('multiprocessing.get_context'):
            scheduler = DPScheduler(
                vllm_config=mock_vllm_config,
                kv_cache_config=mock_kv_cache_config,
                structured_output_manager=mock_structured_output_manager,
                block_size=16,
            )

            mock_request = MagicMock(spec=Request)

            # Mock the queues
            scheduler.input_queues = [MagicMock(), MagicMock()]
            scheduler.output_queues = [MagicMock(), MagicMock()]

            # Per-rank reply sequences, consumed one item per get() call.
            # (An earlier shared-sequence lambda setup was removed: it was
            # overwritten by these assignments before ever being used.)
            responses_0 = [100, ([], 10)]
            responses_1 = [50, ([], 25)]
            scheduler.output_queues[0].get = MagicMock(
                side_effect=responses_0)
            scheduler.output_queues[1].get = MagicMock(
                side_effect=responses_1)

            rank = scheduler._find_best_rank_for_request(mock_request)

            # Should prefer rank with better cache hit
            assert rank == 1
177
+
178
def test_find_best_rank_without_cache_hit(self, mock_vllm_config,
                                          mock_kv_cache_config,
                                          mock_structured_output_manager):
    """With zero cache hits on every rank, the least-loaded rank is chosen."""
    worker_patch = patch(
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with worker_patch, patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        request = MagicMock(spec=Request)

        # Per-rank reply sequences: a scalar first, then a ([], 0) tuple,
        # i.e. no cache hit on either rank.
        replies_per_rank = (
            [100, ([], 0)],
            [50, ([], 0)],
        )
        scheduler.input_queues = [MagicMock(), MagicMock()]
        scheduler.output_queues = [MagicMock(), MagicMock()]
        for out_queue, replies in zip(scheduler.output_queues,
                                      replies_per_rank):
            out_queue.get = MagicMock(side_effect=replies)

        chosen_rank = scheduler._find_best_rank_for_request(request)

        # Rank 1 reported fewer tokens (50 vs. 100), so it should win.
        assert chosen_rank == 1
209
+
210
def test_add_request_assigns_to_best_rank(self, mock_vllm_config,
                                          mock_kv_cache_config,
                                          mock_structured_output_manager):
    """add_request records the chosen rank and forwards the request to it."""
    worker_patch = patch(
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with worker_patch, patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        request = MagicMock(spec=Request)
        request.request_id = "req1"

        # Replace the queues with mocks, each with an explicit get() mock.
        scheduler.input_queues = [MagicMock(), MagicMock()]
        scheduler.output_queues = [MagicMock(), MagicMock()]
        for out_queue in scheduler.output_queues:
            out_queue.get = MagicMock()

        # Force rank selection so the routing can be checked in isolation.
        scheduler._find_best_rank_for_request = MagicMock(return_value=1)

        scheduler.add_request(request)

        # The request is recorded against rank 1 ...
        assert scheduler.assigned_dp_rank["req1"] == 1

        # ... the ADD_REQUEST command went to rank 1's input queue ...
        target_input = scheduler.input_queues[1]
        target_input.put.assert_called_with(
            (SchedulerCommand.ADD_REQUEST, request))

        # ... and exactly one reply was read from rank 1's output queue.
        target_output = scheduler.output_queues[1]
        target_output.get.assert_called_once()
249
+
250
def test_schedule_sends_commands_and_combines_output(
        self, mock_vllm_config, mock_kv_cache_config,
        mock_structured_output_manager):
    """schedule() sends SCHEDULE to every worker and merges their outputs."""

    def make_rank_output(num_scheduled_tokens, total_tokens):
        # SchedulerOutput stand-in with empty new/cached/finished request
        # data; only the scheduled-token fields differ between ranks.
        rank_output = MagicMock(spec=SchedulerOutput)
        rank_output.scheduled_new_reqs = []
        rank_output.num_scheduled_tokens = num_scheduled_tokens
        rank_output.total_num_scheduled_tokens = total_tokens
        rank_output.finished_req_ids = set()
        rank_output.scheduled_cached_reqs = CachedRequestData(
            req_ids=[],
            resumed_req_ids=[],
            new_token_ids=[],
            all_token_ids=[],
            new_block_ids=[],
            num_computed_tokens=[],
            num_output_tokens=[],
        )
        rank_output.scheduled_spec_decode_tokens = {}
        rank_output.scheduled_encoder_inputs = {}
        rank_output.num_common_prefix_blocks = []
        return rank_output

    worker_patch = patch(
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with worker_patch, patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        scheduler.input_queues = [MagicMock(), MagicMock()]
        scheduler.output_queues = [MagicMock(), MagicMock()]

        # Rank 0 schedules 10 tokens for req1, rank 1 schedules 20 for req2.
        rank_outputs = [
            make_rank_output({"req1": 10}, 10),
            make_rank_output({"req2": 20}, 20),
        ]
        for out_queue, rank_output in zip(scheduler.output_queues,
                                          rank_outputs):
            out_queue.get = MagicMock(return_value=rank_output)

        scheduler.assigned_dp_rank = {"req1": 0, "req2": 1}

        combined = scheduler.schedule()

        # Every rank must have been told to schedule.
        for in_queue in scheduler.input_queues:
            in_queue.put.assert_called_with(
                (SchedulerCommand.SCHEDULE, None))

        # The merged output sums the token totals and keeps both requests.
        assert isinstance(combined, DPSchedulerOutput)
        assert combined.total_num_scheduled_tokens == 30
        for req_id in ("req1", "req2"):
            assert req_id in combined.num_scheduled_tokens
        assert combined.assigned_dp_rank == {"req1": 0, "req2": 1}
329
+
330
def test_combine_cached_request_data(self, mock_vllm_config,
                                     mock_kv_cache_config,
                                     mock_structured_output_manager):
    """_combine_cached_request_data concatenates per-rank cached request data."""
    worker_patch = patch(
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with worker_patch, patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        # Rank 0 carries one resumed request ("req1").
        rank0_cached = CachedRequestData(
            req_ids=["req1"],
            resumed_req_ids=["req1"],
            new_token_ids=[[1, 2, 3]],
            all_token_ids=[[1, 2, 3, 4, 5]],
            new_block_ids=[[10, 11]],
            num_computed_tokens=[5],
            num_output_tokens=[3],
        )
        # Rank 1 carries one request ("req2") that is not resumed.
        rank1_cached = CachedRequestData(
            req_ids=["req2"],
            resumed_req_ids=[],
            new_token_ids=[[6, 7]],
            all_token_ids=[[6, 7, 8, 9]],
            new_block_ids=[[20, 21]],
            num_computed_tokens=[4],
            num_output_tokens=[2],
        )

        rank_outputs = []
        for cached in (rank0_cached, rank1_cached):
            rank_output = MagicMock(spec=SchedulerOutput)
            rank_output.scheduled_cached_reqs = cached
            rank_outputs.append(rank_output)

        merged = scheduler._combine_cached_request_data(rank_outputs)

        # Fields are expected in rank order: rank 0 entries, then rank 1.
        assert merged.req_ids == ["req1", "req2"]
        assert merged.resumed_req_ids == ["req1"]
        assert merged.new_token_ids == [[1, 2, 3], [6, 7]]
        assert merged.num_computed_tokens == [5, 4]
        assert merged.num_output_tokens == [3, 2]
377
+
378
def test_finish_requests_routes_to_workers(self, mock_vllm_config,
                                           mock_kv_cache_config,
                                           mock_structured_output_manager):
    """finish_requests contacts the ranks that own the finished requests."""
    worker_patch = patch(
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with worker_patch, patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        scheduler.input_queues = [MagicMock(), MagicMock()]
        scheduler.output_queues = [MagicMock(), MagicMock()]
        for out_queue in scheduler.output_queues:
            out_queue.get = MagicMock()

        # req1 and req3 live on rank 0, req2 on rank 1.
        scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}

        # Finish one request on each rank, passed as a list of ids.
        finished_ids = ["req1", "req2"]
        scheduler.finish_requests(finished_ids, finished_status="completed")

        # Both ranks own a finished request, so both must receive a command.
        for in_queue in scheduler.input_queues:
            in_queue.put.assert_called()
407
+
408
def test_get_num_unfinished_requests(self, mock_vllm_config,
                                     mock_kv_cache_config,
                                     mock_structured_output_manager):
    """get_num_unfinished_requests sums the counts reported by all workers."""
    worker_patch = patch(
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with worker_patch, patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        # Rank 0 reports 5 unfinished requests, rank 1 reports 3.
        worker_replies = [5, 3]
        scheduler.input_queues = [MagicMock() for _ in worker_replies]
        scheduler.output_queues = [MagicMock() for _ in worker_replies]
        for out_queue, reply in zip(scheduler.output_queues, worker_replies):
            out_queue.get = MagicMock(return_value=reply)

        total_unfinished = scheduler.get_num_unfinished_requests()

        # Every rank must have been queried.
        for in_queue in scheduler.input_queues:
            in_queue.put.assert_called_with(
                (SchedulerCommand.GET_NUM_UNFINISHED_REQUESTS, None))

        assert total_unfinished == 8
438
+
439
def test_has_finished_requests(self, mock_vllm_config,
                               mock_kv_cache_config,
                               mock_structured_output_manager):
    """has_finished_requests is True when a worker reports finished requests."""
    worker_patch = patch(
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with worker_patch, patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        # Rank 0 reports no finished requests, rank 1 reports some.
        worker_replies = [False, True]
        scheduler.input_queues = [MagicMock() for _ in worker_replies]
        scheduler.output_queues = [MagicMock() for _ in worker_replies]
        for out_queue, reply in zip(scheduler.output_queues, worker_replies):
            out_queue.get = MagicMock(return_value=reply)

        any_finished = scheduler.has_finished_requests()

        assert any_finished is True

        # Every rank must have been queried.
        for in_queue in scheduler.input_queues:
            in_queue.put.assert_called_with(
                (SchedulerCommand.HAS_FINISHED_REQUESTS, None))
469
+
470
def test_get_request_counts(self, mock_vllm_config, mock_kv_cache_config,
                            mock_structured_output_manager):
    """get_request_counts sums (running, waiting) pairs across all workers."""
    worker_patch = patch(
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with worker_patch, patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        # Each worker replies with a two-element tuple; the test unpacks
        # the combined result as (running, waiting).
        worker_replies = [(2, 1), (1, 3)]
        scheduler.input_queues = [MagicMock() for _ in worker_replies]
        scheduler.output_queues = [MagicMock() for _ in worker_replies]
        for out_queue, reply in zip(scheduler.output_queues, worker_replies):
            out_queue.get = MagicMock(return_value=reply)

        running, waiting = scheduler.get_request_counts()

        # Every rank must have been queried.
        for in_queue in scheduler.input_queues:
            in_queue.put.assert_called_with(
                (SchedulerCommand.GET_REQUEST_COUNTS, None))

        assert running == 3
        assert waiting == 4
500
+
501
def test_reset_prefix_cache(self, mock_vllm_config, mock_kv_cache_config,
                            mock_structured_output_manager):
    """Check reset_prefix_cache against two mocked workers.

    Both workers reply True; the call must return True and each input
    queue must have received the RESET_PREFIX_CACHE command.
    """
    worker_target = (
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with patch(worker_target), patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        # Swap the real queues for mocks so no worker process is needed.
        scheduler.input_queues = [MagicMock() for _ in range(2)]
        scheduler.output_queues = [MagicMock() for _ in range(2)]
        for out_queue in scheduler.output_queues:
            out_queue.get = MagicMock(return_value=True)

        result = scheduler.reset_prefix_cache()

        # Every worker must have been sent the reset command.
        expected_command = (SchedulerCommand.RESET_PREFIX_CACHE, None)
        for in_queue in scheduler.input_queues:
            in_queue.put.assert_called_with(expected_command)

        assert result is True
530
+
531
def test_make_stats_aggregates_from_workers(
        self, mock_vllm_config, mock_kv_cache_config,
        mock_structured_output_manager):
    """Check make_stats combines the stats returned by two mocked workers.

    The workers report 3 and 4 running requests, 2 and 1 waiting requests,
    and KV-cache usage of 0.5 and 0.7. The combined stats are expected to
    hold 7 running, 3 waiting and a usage of 0.6. Each input queue must
    have received a MAKE_STATS command with a (None, None) payload.
    """
    worker_target = (
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with patch(worker_target), patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
            log_stats=True,
        )

        # Swap the real queues for mocks so no worker process is needed.
        scheduler.input_queues = [MagicMock(), MagicMock()]
        scheduler.output_queues = [MagicMock(), MagicMock()]

        # Stats reported by worker 0.
        stats_0 = SchedulerStats(
            num_running_reqs=3,
            num_waiting_reqs=2,
            kv_cache_usage=0.5,
            prefix_cache_stats=PrefixCacheStats(reset=False,
                                                requests=10,
                                                queries=8,
                                                hits=5),
            connector_prefix_cache_stats=PrefixCacheStats(reset=False,
                                                          requests=5,
                                                          queries=4,
                                                          hits=2),
            spec_decoding_stats=None,
            kv_connector_stats=None,
        )

        # Stats reported by worker 1.
        stats_1 = SchedulerStats(
            num_running_reqs=4,
            num_waiting_reqs=1,
            kv_cache_usage=0.7,
            prefix_cache_stats=PrefixCacheStats(reset=False,
                                                requests=15,
                                                queries=12,
                                                hits=8),
            connector_prefix_cache_stats=PrefixCacheStats(reset=False,
                                                          requests=6,
                                                          queries=5,
                                                          hits=3),
            spec_decoding_stats=None,
            kv_connector_stats=None,
        )

        scheduler.output_queues[0].get = MagicMock(return_value=stats_0)
        scheduler.output_queues[1].get = MagicMock(return_value=stats_1)

        combined_stats = scheduler.make_stats()

        # Every worker must have been sent the stats command.
        scheduler.input_queues[0].put.assert_called_with(
            (SchedulerCommand.MAKE_STATS, (None, None)))
        scheduler.input_queues[1].put.assert_called_with(
            (SchedulerCommand.MAKE_STATS, (None, None)))

        assert combined_stats.num_running_reqs == 7
        assert combined_stats.num_waiting_reqs == 3
        # 0.6 is the average of 0.5 and 0.7. Compare with a tolerance so the
        # test does not depend on the exact floating-point evaluation order
        # used by the aggregation.
        assert combined_stats.kv_cache_usage == pytest.approx(0.6)
599
+
600
def test_make_stats_returns_none_when_disabled(
        self, mock_vllm_config, mock_kv_cache_config,
        mock_structured_output_manager):
    """make_stats must return None for a scheduler built with log_stats=False."""
    worker_target = (
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with patch(worker_target), patch('multiprocessing.get_context'):
        scheduler_kwargs = dict(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
            log_stats=False,
        )
        scheduler = DPScheduler(**scheduler_kwargs)

        assert scheduler.make_stats() is None
618
+
619
def test_update_draft_token_ids(self, mock_vllm_config,
                                mock_kv_cache_config,
                                mock_structured_output_manager):
    """Check update_draft_token_ids sends a command to both workers.

    Requests req1 and req3 are assigned to rank 0 and req2 to rank 1, so
    both input queues are expected to receive a put.
    """
    worker_target = (
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with patch(worker_target), patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        # Swap the real queues for mocks so no worker process is needed.
        scheduler.input_queues = [MagicMock() for _ in range(2)]
        scheduler.output_queues = [MagicMock() for _ in range(2)]
        for out_queue in scheduler.output_queues:
            out_queue.get = MagicMock()

        scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}

        # Draft tokens keyed by request; insertion order defines both the
        # req_ids list and the parallel draft_token_ids list.
        tokens_by_request = {
            "req1": [101, 102, 103],
            "req2": [201, 202],
            "req3": [301, 302, 303, 304],
        }
        draft_token_ids = MagicMock()
        draft_token_ids.req_ids = list(tokens_by_request)
        draft_token_ids.draft_token_ids = list(tokens_by_request.values())

        scheduler.update_draft_token_ids(draft_token_ids)

        # Only checks that each worker received some command; the payload
        # itself is not inspected.
        for in_queue in scheduler.input_queues:
            in_queue.put.assert_called()
654
+
655
def test_shutdown(self, mock_vllm_config, mock_kv_cache_config,
                  mock_structured_output_manager):
    """Check shutdown sends SHUTDOWN to every worker and joins its process."""
    worker_target = (
        'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process')
    with patch(worker_target), patch('multiprocessing.get_context'):
        scheduler = DPScheduler(
            vllm_config=mock_vllm_config,
            kv_cache_config=mock_kv_cache_config,
            structured_output_manager=mock_structured_output_manager,
            block_size=16,
        )

        # Swap the real queues for mocks so no worker process is needed.
        scheduler.input_queues = [MagicMock() for _ in range(2)]
        scheduler.output_queues = [MagicMock() for _ in range(2)]
        for out_queue in scheduler.output_queues:
            out_queue.get = MagicMock()

        # Stand-in worker processes that report themselves as not alive.
        worker_processes = []
        for _ in range(2):
            process = MagicMock()
            process.is_alive = MagicMock(return_value=False)
            worker_processes.append(process)
        scheduler.processes = worker_processes

        scheduler.shutdown()

        # Every worker must have been sent the SHUTDOWN command.
        expected_command = (SchedulerCommand.SHUTDOWN, None)
        for in_queue in scheduler.input_queues:
            in_queue.put.assert_called_with(expected_command)

        # Every worker process must have been joined.
        for process in worker_processes:
            process.join.assert_called()
691
+
692
+
693
class TestUpdateVllmConfigForDPScheduler:
    """Tests for update_vllm_config_for_dp_scheduler."""

    def test_update_config_with_dp_size_greater_than_one(self):
        """With total_dp_size == 2 the scheduler class becomes DPScheduler.

        The previous scheduler is expected to be recorded in
        _original_scheduler_cls as the Scheduler class.
        """
        config = MagicMock()
        config.sharding_config.total_dp_size = 2

        # Configure the scheduler section through a short-lived alias.
        sched_section = config.scheduler_config
        sched_section._original_scheduler_cls = None
        sched_section.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"
        sched_section.async_scheduling = False

        update_vllm_config_for_dp_scheduler(config)

        # Read back through the config so any replaced attribute is seen.
        assert config.scheduler_config._original_scheduler_cls == Scheduler
        assert config.scheduler_config.scheduler_cls == DPScheduler

    def test_update_config_with_dp_size_one(self):
        """With total_dp_size == 1 the scheduler class is left untouched."""
        unchanged_cls = "vllm.v1.core.sched.scheduler.Scheduler"
        config = MagicMock()
        config.sharding_config.total_dp_size = 1
        config.scheduler_config.scheduler_cls = unchanged_cls

        update_vllm_config_for_dp_scheduler(config)

        assert config.scheduler_config.scheduler_cls == unchanged_cls
721
+
722
+
723
if __name__ == "__main__":
    # pytest.main returns the session exit code; propagate it so that
    # running this file directly exits non-zero when a test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))