tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -1,19 +1,30 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  from unittest.mock import MagicMock, patch
2
16
 
3
17
  import pytest
4
- import torch
5
18
  from vllm.config import VllmConfig
6
- from vllm.v1.core.sched.output import (CachedRequestData, GrammarOutput,
7
- SchedulerOutput)
19
+ from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
8
20
  from vllm.v1.core.sched.scheduler import Scheduler
9
- from vllm.v1.engine import EngineCoreOutputs
10
21
  from vllm.v1.kv_cache_interface import KVCacheConfig
11
22
  from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
12
- from vllm.v1.outputs import ModelRunnerOutput
13
23
  from vllm.v1.request import Request
14
24
 
15
25
  from tpu_inference.core.sched.dp_scheduler import (
16
- DPScheduler, DPSchedulerOutput, update_vllm_config_for_dp_scheduler)
26
+ DPScheduler, DPSchedulerOutput, SchedulerCommand,
27
+ update_vllm_config_for_dp_scheduler)
17
28
 
18
29
 
19
30
  class TestDPScheduler:
@@ -43,387 +54,241 @@ class TestDPScheduler:
43
54
  """Create a mock StructuredOutputManager."""
44
55
  return MagicMock()
45
56
 
46
- def _create_dp_scheduler_with_mocks(self, mock_vllm_config,
47
- mock_kv_cache_config,
48
- mock_structured_output_manager,
49
- **kwargs):
50
- """Helper to create a DPScheduler with properly mocked schedulers."""
51
- # Create individual mock scheduler instances
52
- mock_scheduler_0 = MagicMock()
53
- mock_scheduler_1 = MagicMock()
54
-
55
- # Patch the Scheduler class to return our mock instances
56
- with patch.object(
57
- mock_vllm_config.scheduler_config, '_original_scheduler_cls',
58
- MagicMock(side_effect=[mock_scheduler_0, mock_scheduler_1])):
59
- scheduler = DPScheduler(
60
- vllm_config=mock_vllm_config,
61
- kv_cache_config=mock_kv_cache_config,
62
- structured_output_manager=mock_structured_output_manager,
63
- block_size=16,
64
- **kwargs)
65
-
66
- return scheduler
67
-
68
- def test_init_creates_per_rank_schedulers(
57
+ def test_init_creates_worker_processes(
69
58
  self,
70
59
  mock_vllm_config,
71
60
  mock_kv_cache_config,
72
61
  mock_structured_output_manager,
73
62
  ):
74
- """Test Initialization creates schedulers for each DP rank."""
75
- # Mock the scheduler class
76
- mock_scheduler_instance = MagicMock()
77
- mock_scheduler_cls = MagicMock(return_value=mock_scheduler_instance)
78
-
79
- with patch.object(mock_vllm_config.scheduler_config,
80
- '_original_scheduler_cls', mock_scheduler_cls):
81
- scheduler = DPScheduler(
82
- vllm_config=mock_vllm_config,
83
- kv_cache_config=mock_kv_cache_config,
84
- structured_output_manager=mock_structured_output_manager,
85
- block_size=16,
86
- log_stats=True,
87
- )
88
-
89
- # Verify schedulers were created
90
- assert len(scheduler.schedulers) == 2
91
- assert scheduler.dp_size == 2
92
- assert scheduler.log_stats is True
93
- assert len(scheduler.per_rank_kv_cache_configs) == 2
94
-
95
- # Verify each rank got the correct config
96
- for rank_config in scheduler.per_rank_kv_cache_configs:
97
- assert rank_config.num_blocks == 50 # 100 / 2
63
+ """Test initialization creates worker processes for each DP rank."""
64
+ with patch(
65
+ 'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
66
+ ):
67
+ with patch('multiprocessing.get_context') as mock_get_context:
68
+ # Setup mock context
69
+ mock_ctx = MagicMock()
70
+ mock_process = MagicMock()
71
+ mock_queue = MagicMock()
72
+
73
+ mock_ctx.Queue = MagicMock(return_value=mock_queue)
74
+ mock_ctx.Process = MagicMock(return_value=mock_process)
75
+ mock_get_context.return_value = mock_ctx
76
+
77
+ scheduler = DPScheduler(
78
+ vllm_config=mock_vllm_config,
79
+ kv_cache_config=mock_kv_cache_config,
80
+ structured_output_manager=mock_structured_output_manager,
81
+ block_size=16,
82
+ log_stats=True,
83
+ )
84
+
85
+ # Verify processes and queues were created
86
+ assert scheduler.dp_size == 2
87
+ assert len(scheduler.processes) == 2
88
+ assert len(scheduler.input_queues) == 2
89
+ # output_queues is a dict with (rank, command) tuple keys
90
+ # 2 ranks × 14 commands (SchedulerCommand enum)
91
+ assert len(scheduler.output_queues) == 28
92
+ assert scheduler.log_stats is True
93
+ assert len(scheduler.per_rank_kv_cache_configs) == 2
94
+
95
+ # Verify each rank got the correct config
96
+ for rank_config in scheduler.per_rank_kv_cache_configs:
97
+ assert rank_config.num_blocks == 50 # 100 / 2
98
+
99
+ # Verify processes were started
100
+ assert mock_process.start.call_count == 2
98
101
 
99
102
  def test_get_rank_token_counts(self, mock_vllm_config,
100
103
  mock_kv_cache_config,
101
104
  mock_structured_output_manager):
102
- """Test _get_rank_token_counts calculates tokens per rank."""
103
- scheduler = self._create_dp_scheduler_with_mocks(
104
- mock_vllm_config, mock_kv_cache_config,
105
- mock_structured_output_manager)
106
-
107
- # Mock requests on different ranks
108
- req1 = MagicMock()
109
- req1.num_tokens = 10
110
- req2 = MagicMock()
111
- req2.num_tokens = 20
112
- req3 = MagicMock()
113
- req3.num_tokens = 15
114
-
115
- scheduler.schedulers[0].running = [req1]
116
- scheduler.schedulers[0].waiting = [req2]
117
- scheduler.schedulers[1].running = [req3]
118
- scheduler.schedulers[1].waiting = []
119
-
120
- rank_tokens = scheduler._get_rank_token_counts()
121
-
122
- assert rank_tokens[0] == 30 # 10 + 20
123
- assert rank_tokens[1] == 15
105
+ """Test _get_rank_token_counts queries workers and aggregates tokens."""
106
+ with patch(
107
+ 'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
108
+ ):
109
+ with patch('multiprocessing.get_context'):
110
+ scheduler = DPScheduler(
111
+ vllm_config=mock_vllm_config,
112
+ kv_cache_config=mock_kv_cache_config,
113
+ structured_output_manager=mock_structured_output_manager,
114
+ block_size=16,
115
+ )
116
+
117
+ # Mock the queues - need to mock the .get() method to return the value
118
+ scheduler.input_queues = [MagicMock(), MagicMock()]
119
+
120
+ mock_queue_0 = MagicMock()
121
+ mock_queue_0.get.return_value = 30
122
+ mock_queue_1 = MagicMock()
123
+ mock_queue_1.get.return_value = 15
124
+
125
+ scheduler.output_queues = {
126
+ (0, "get_token_count"): mock_queue_0,
127
+ (1, "get_token_count"): mock_queue_1,
128
+ }
129
+
130
+ rank_tokens = scheduler._get_rank_token_counts()
131
+
132
+ # Verify correct commands were sent
133
+ scheduler.input_queues[0].put.assert_called_with(
134
+ (SchedulerCommand.GET_TOKEN_COUNT, None))
135
+ scheduler.input_queues[1].put.assert_called_with(
136
+ (SchedulerCommand.GET_TOKEN_COUNT, None))
137
+
138
+ assert rank_tokens[0] == 30
139
+ assert rank_tokens[1] == 15
124
140
 
125
141
  def test_find_best_rank_with_cache_hit(self, mock_vllm_config,
126
142
  mock_kv_cache_config,
127
143
  mock_structured_output_manager):
128
- """Test _find_best_rank_for_request with cache hit."""
129
- scheduler = self._create_dp_scheduler_with_mocks(
130
- mock_vllm_config, mock_kv_cache_config,
131
- mock_structured_output_manager)
132
-
133
- # Mock request
134
- mock_request = MagicMock(spec=Request)
135
-
136
- # Mock KV cache managers with different cache hits
137
- scheduler.schedulers[0].kv_cache_manager = MagicMock()
138
- scheduler.schedulers[
139
- 0].kv_cache_manager.get_computed_blocks.return_value = (
140
- [],
141
- 10,
142
- ) # 10 cached tokens
143
-
144
- scheduler.schedulers[1].kv_cache_manager = MagicMock()
145
- scheduler.schedulers[
146
- 1].kv_cache_manager.get_computed_blocks.return_value = (
147
- [],
148
- 20,
149
- ) # 20 cached tokens (better)
150
-
151
- # Mock empty running/waiting queues
152
- scheduler.schedulers[0].running = []
153
- scheduler.schedulers[0].waiting = []
154
- scheduler.schedulers[1].running = []
155
- scheduler.schedulers[1].waiting = []
156
-
157
- rank = scheduler._find_best_rank_for_request(mock_request)
158
-
159
- # Should choose rank 1 with better cache hit
160
- assert rank == 1
144
+ """Test _find_best_rank_for_request prefers cache hits."""
145
+ with patch(
146
+ 'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
147
+ ):
148
+ with patch('multiprocessing.get_context'):
149
+ scheduler = DPScheduler(
150
+ vllm_config=mock_vllm_config,
151
+ kv_cache_config=mock_kv_cache_config,
152
+ structured_output_manager=mock_structured_output_manager,
153
+ block_size=16,
154
+ )
155
+
156
+ mock_request = MagicMock(spec=Request)
157
+
158
+ # Mock the queues with tuple keys (rank, command)
159
+ scheduler.input_queues = [MagicMock(), MagicMock()]
160
+
161
+ # Create proper mocks for queue.get() calls
162
+ mock_queue_get_token_0 = MagicMock()
163
+ mock_queue_get_token_0.get.return_value = 100
164
+ mock_queue_get_token_1 = MagicMock()
165
+ mock_queue_get_token_1.get.return_value = 50
166
+ mock_queue_computed_0 = MagicMock()
167
+ mock_queue_computed_0.get.return_value = ([], 10)
168
+ mock_queue_computed_1 = MagicMock()
169
+ mock_queue_computed_1.get.return_value = ([], 25)
170
+
171
+ scheduler.output_queues = {
172
+ (0, "get_token_count"): mock_queue_get_token_0,
173
+ (1, "get_token_count"): mock_queue_get_token_1,
174
+ (0, "get_computed_blocks"): mock_queue_computed_0,
175
+ (1, "get_computed_blocks"): mock_queue_computed_1,
176
+ }
177
+
178
+ rank = scheduler._find_best_rank_for_request(mock_request)
179
+
180
+ # Should prefer rank with better cache hit
181
+ assert rank == 1
161
182
 
162
183
  def test_find_best_rank_without_cache_hit(self, mock_vllm_config,
163
184
  mock_kv_cache_config,
164
185
  mock_structured_output_manager):
165
- """Test _find_best_rank_for_request without cache hit (load balancing)."""
166
- scheduler = self._create_dp_scheduler_with_mocks(
167
- mock_vllm_config, mock_kv_cache_config,
168
- mock_structured_output_manager)
169
-
170
- # Mock request
171
- mock_request = MagicMock(spec=Request)
172
-
173
- # Mock KV cache managers with no cache hits
174
- scheduler.schedulers[0].kv_cache_manager = MagicMock()
175
- scheduler.schedulers[
176
- 0].kv_cache_manager.get_computed_blocks.return_value = ([], 0)
177
-
178
- scheduler.schedulers[1].kv_cache_manager = MagicMock()
179
- scheduler.schedulers[
180
- 1].kv_cache_manager.get_computed_blocks.return_value = ([], 0)
181
-
182
- # Mock requests with different token counts
183
- req1 = MagicMock()
184
- req1.num_tokens = 50
185
- req2 = MagicMock()
186
- req2.num_tokens = 30
187
-
188
- scheduler.schedulers[0].running = [req1]
189
- scheduler.schedulers[0].waiting = []
190
- scheduler.schedulers[1].running = [req2]
191
- scheduler.schedulers[1].waiting = []
192
-
193
- rank = scheduler._find_best_rank_for_request(mock_request)
194
-
195
- # Should choose rank 1 with fewer tokens
196
- assert rank == 1
186
+ """Test _find_best_rank_for_request uses load balancing without cache hit."""
187
+ with patch(
188
+ 'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
189
+ ):
190
+ with patch('multiprocessing.get_context'):
191
+ scheduler = DPScheduler(
192
+ vllm_config=mock_vllm_config,
193
+ kv_cache_config=mock_kv_cache_config,
194
+ structured_output_manager=mock_structured_output_manager,
195
+ block_size=16,
196
+ )
197
+
198
+ mock_request = MagicMock(spec=Request)
199
+
200
+ # Mock the queues with tuple keys (rank, command)
201
+ scheduler.input_queues = [MagicMock(), MagicMock()]
202
+
203
+ # Create proper mocks for queue.get() calls
204
+ mock_queue_get_token_0 = MagicMock()
205
+ mock_queue_get_token_0.get.return_value = 100
206
+ mock_queue_get_token_1 = MagicMock()
207
+ mock_queue_get_token_1.get.return_value = 50
208
+ mock_queue_computed_0 = MagicMock()
209
+ mock_queue_computed_0.get.return_value = ([], 0)
210
+ mock_queue_computed_1 = MagicMock()
211
+ mock_queue_computed_1.get.return_value = ([], 0)
212
+
213
+ scheduler.output_queues = {
214
+ (0, "get_token_count"): mock_queue_get_token_0,
215
+ (1, "get_token_count"): mock_queue_get_token_1,
216
+ (0, "get_computed_blocks"): mock_queue_computed_0,
217
+ (1, "get_computed_blocks"): mock_queue_computed_1,
218
+ }
219
+
220
+ rank = scheduler._find_best_rank_for_request(mock_request)
221
+
222
+ # Should choose rank with fewer tokens (rank 1)
223
+ assert rank == 1
197
224
 
198
225
  def test_add_request_assigns_to_best_rank(self, mock_vllm_config,
199
226
  mock_kv_cache_config,
200
227
  mock_structured_output_manager):
201
- """Test add_request assigns and adds request to best rank."""
202
- scheduler = self._create_dp_scheduler_with_mocks(
203
- mock_vllm_config, mock_kv_cache_config,
204
- mock_structured_output_manager)
205
-
206
- # Mock the rank selection
207
- mock_request = MagicMock(spec=Request)
208
- mock_request.request_id = "req1"
209
-
210
- # Mock _find_best_rank_for_request to return rank 1
211
- scheduler._find_best_rank_for_request = MagicMock(return_value=1)
212
-
213
- # Mock schedulers
214
- scheduler.schedulers[0].add_request = MagicMock()
215
- scheduler.schedulers[1].add_request = MagicMock()
216
-
217
- scheduler.add_request(mock_request)
218
-
219
- # Verify request was assigned to rank 1
220
- assert scheduler.assigned_dp_rank["req1"] == 1
221
- scheduler.schedulers[1].add_request.assert_called_once_with(
222
- mock_request)
223
- scheduler.schedulers[0].add_request.assert_not_called()
224
-
225
- def test_schedule_runs_all_schedulers(self, mock_vllm_config,
226
- mock_kv_cache_config,
227
- mock_structured_output_manager):
228
- """Test schedule runs all schedulers and combines output."""
229
- scheduler = self._create_dp_scheduler_with_mocks(
230
- mock_vllm_config, mock_kv_cache_config,
231
- mock_structured_output_manager)
232
-
233
- # Mock scheduler outputs
234
- mock_output_0 = MagicMock(spec=SchedulerOutput)
235
- mock_output_0.scheduled_new_reqs = []
236
- mock_output_0.num_scheduled_tokens = {"req1": 10}
237
- mock_output_0.total_num_scheduled_tokens = 10
238
- mock_output_0.finished_req_ids = set()
239
- mock_output_0.scheduled_cached_reqs = CachedRequestData(
240
- req_ids=[],
241
- resumed_req_ids=[],
242
- new_token_ids=[],
243
- all_token_ids=[],
244
- new_block_ids=[],
245
- num_computed_tokens=[],
246
- num_output_tokens=[],
247
- )
248
- mock_output_0.scheduled_spec_decode_tokens = {}
249
- mock_output_0.scheduled_encoder_inputs = {}
250
- mock_output_0.num_common_prefix_blocks = []
251
-
252
- mock_output_1 = MagicMock(spec=SchedulerOutput)
253
- mock_output_1.scheduled_new_reqs = []
254
- mock_output_1.num_scheduled_tokens = {"req2": 20}
255
- mock_output_1.total_num_scheduled_tokens = 20
256
- mock_output_1.finished_req_ids = set()
257
- mock_output_1.scheduled_cached_reqs = CachedRequestData(
258
- req_ids=[],
259
- resumed_req_ids=[],
260
- new_token_ids=[],
261
- all_token_ids=[],
262
- new_block_ids=[],
263
- num_computed_tokens=[],
264
- num_output_tokens=[],
265
- )
266
- mock_output_1.scheduled_spec_decode_tokens = {}
267
- mock_output_1.scheduled_encoder_inputs = {}
268
- mock_output_1.num_common_prefix_blocks = []
269
-
270
- scheduler.schedulers[0].schedule = MagicMock(
271
- return_value=mock_output_0)
272
- scheduler.schedulers[1].schedule = MagicMock(
273
- return_value=mock_output_1)
274
- scheduler.schedulers[0].running = []
275
- scheduler.schedulers[0].waiting = []
276
- scheduler.schedulers[1].running = []
277
- scheduler.schedulers[1].waiting = []
278
-
279
- # Assign ranks for requests
280
- scheduler.assigned_dp_rank = {"req1": 0, "req2": 1}
281
-
282
- output = scheduler.schedule()
283
-
284
- # Verify combined output
285
- assert isinstance(output, DPSchedulerOutput)
286
- assert output.total_num_scheduled_tokens == 30 # 10 + 20
287
- assert "req1" in output.num_scheduled_tokens
288
- assert "req2" in output.num_scheduled_tokens
289
- assert output.assigned_dp_rank == {"req1": 0, "req2": 1}
290
-
291
- def test_combine_cached_request_data(self, mock_vllm_config,
292
- mock_kv_cache_config,
293
- mock_structured_output_manager):
294
- """Test _combine_cached_request_data combines data from all ranks."""
295
- mock_scheduler_cls = MagicMock(return_value=MagicMock())
296
- with patch.object(mock_vllm_config.scheduler_config,
297
- '_original_scheduler_cls', mock_scheduler_cls):
298
- scheduler = DPScheduler(
299
- vllm_config=mock_vllm_config,
300
- kv_cache_config=mock_kv_cache_config,
301
- structured_output_manager=mock_structured_output_manager,
302
- block_size=16,
303
- )
304
-
305
- # Create mock rank outputs with different cached request data
306
- output_0 = MagicMock(spec=SchedulerOutput)
307
- output_0.scheduled_cached_reqs = CachedRequestData(
308
- req_ids=["req1"],
309
- resumed_req_ids=["req1"],
310
- new_token_ids=[[1, 2, 3]],
311
- all_token_ids=[[1, 2, 3, 4, 5]],
312
- new_block_ids=[[10, 11]],
313
- num_computed_tokens=[5],
314
- num_output_tokens=[3],
315
- )
316
-
317
- output_1 = MagicMock(spec=SchedulerOutput)
318
- output_1.scheduled_cached_reqs = CachedRequestData(
319
- req_ids=["req2"],
320
- resumed_req_ids=[],
321
- new_token_ids=[[6, 7]],
322
- all_token_ids=[[6, 7, 8, 9]],
323
- new_block_ids=[[20, 21]],
324
- num_computed_tokens=[4],
325
- num_output_tokens=[2],
326
- )
327
-
328
- rank_outputs = [output_0, output_1]
329
- combined = scheduler._combine_cached_request_data(rank_outputs)
330
-
331
- # Verify combined data
332
- assert combined.req_ids == ["req1", "req2"]
333
- assert combined.resumed_req_ids == ["req1"]
334
- assert combined.new_token_ids == [[1, 2, 3], [6, 7]]
335
- assert combined.all_token_ids == [[1, 2, 3, 4, 5], [6, 7, 8, 9]]
336
- assert combined.new_block_ids == [[10, 11], [20, 21]]
337
- assert combined.num_computed_tokens == [5, 4]
338
- assert combined.num_output_tokens == [3, 2]
339
-
340
- def test_get_grammar_bitmask_with_structured_output(
228
+ """Test add_request assigns request to best rank."""
229
+ with patch(
230
+ 'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
231
+ ):
232
+ with patch('multiprocessing.get_context'):
233
+ scheduler = DPScheduler(
234
+ vllm_config=mock_vllm_config,
235
+ kv_cache_config=mock_kv_cache_config,
236
+ structured_output_manager=mock_structured_output_manager,
237
+ block_size=16,
238
+ )
239
+
240
+ mock_request = MagicMock(spec=Request)
241
+ mock_request.request_id = "req1"
242
+
243
+ # Mock the queues with tuple keys
244
+ scheduler.input_queues = [MagicMock(), MagicMock()]
245
+ scheduler.output_queues = {
246
+ (0, "add_request"): MagicMock(),
247
+ (1, "add_request"): MagicMock(),
248
+ }
249
+
250
+ # Mock _find_best_rank_for_request to return rank 1
251
+ scheduler._find_best_rank_for_request = MagicMock(
252
+ return_value=1)
253
+
254
+ scheduler.add_request(mock_request)
255
+
256
+ # Verify request was assigned to rank 1
257
+ assert scheduler.assigned_dp_rank["req1"] == 1
258
+
259
+ # Verify ADD_REQUEST command was sent to rank 1
260
+ scheduler.input_queues[1].put.assert_called_with(
261
+ (SchedulerCommand.ADD_REQUEST, mock_request))
262
+
263
+ # Verify we waited for completion
264
+ scheduler.output_queues[(
265
+ 1, "add_request")].get.assert_called_once()
266
+
267
+ def test_schedule_sends_commands_and_combines_output(
341
268
  self, mock_vllm_config, mock_kv_cache_config,
342
269
  mock_structured_output_manager):
343
- """Test get_grammar_bitmask combines bitmasks from all ranks."""
344
- scheduler = self._create_dp_scheduler_with_mocks(
345
- mock_vllm_config, mock_kv_cache_config,
346
- mock_structured_output_manager)
347
-
348
- # Create mock scheduler outputs
349
- mock_output_0 = MagicMock()
350
- mock_output_1 = MagicMock()
351
-
352
- # Mock grammar outputs from each rank
353
- grammar_output_0 = GrammarOutput(
354
- structured_output_request_ids=["req1"],
355
- grammar_bitmask=torch.ones((1, 100), dtype=torch.bool),
356
- )
357
- grammar_output_1 = GrammarOutput(
358
- structured_output_request_ids=["req2"],
359
- grammar_bitmask=torch.ones((1, 100), dtype=torch.bool) * 0,
360
- )
361
-
362
- scheduler.schedulers[0].get_grammar_bitmask = MagicMock(
363
- return_value=grammar_output_0)
364
- scheduler.schedulers[1].get_grammar_bitmask = MagicMock(
365
- return_value=grammar_output_1)
366
-
367
- # Cache scheduler outputs
368
- scheduler.cached_schedulers_output.append(
369
- [mock_output_0, mock_output_1])
370
-
371
- # Create a DPSchedulerOutput
372
- dp_output = DPSchedulerOutput(
373
- scheduled_new_reqs=[],
374
- scheduled_cached_reqs=CachedRequestData(
375
- req_ids=[],
376
- resumed_req_ids=[],
377
- new_token_ids=[],
378
- all_token_ids=[],
379
- new_block_ids=[],
380
- num_computed_tokens=[],
381
- num_output_tokens=[],
382
- ),
383
- num_scheduled_tokens={},
384
- total_num_scheduled_tokens=0,
385
- scheduled_spec_decode_tokens={},
386
- scheduled_encoder_inputs={},
387
- num_common_prefix_blocks=[],
388
- finished_req_ids=set(),
389
- free_encoder_mm_hashes=set(),
390
- )
391
-
392
- result = scheduler.get_grammar_bitmask(dp_output)
393
-
394
- assert result is not None
395
- assert result.structured_output_request_ids == ["req1", "req2"]
396
- assert result.grammar_bitmask.shape == (2, 100)
397
-
398
- def test_get_grammar_bitmask_no_structured_output(
399
- self, mock_vllm_config, mock_kv_cache_config,
400
- mock_structured_output_manager):
401
- """Test get_grammar_bitmask returns None when no structured output."""
402
- mock_scheduler_cls = MagicMock(return_value=MagicMock())
403
- with patch.object(mock_vllm_config.scheduler_config,
404
- '_original_scheduler_cls', mock_scheduler_cls):
405
- scheduler = DPScheduler(
406
- vllm_config=mock_vllm_config,
407
- kv_cache_config=mock_kv_cache_config,
408
- structured_output_manager=mock_structured_output_manager,
409
- block_size=16,
410
- )
411
-
412
- # Mock schedulers returning None
413
- scheduler.schedulers[0].get_grammar_bitmask = MagicMock(
414
- return_value=None)
415
- scheduler.schedulers[1].get_grammar_bitmask = MagicMock(
416
- return_value=None)
417
-
418
- # Cache scheduler outputs
419
- mock_output_0 = MagicMock()
420
- mock_output_1 = MagicMock()
421
- scheduler.cached_schedulers_output.append(
422
- [mock_output_0, mock_output_1])
423
-
424
- dp_output = DPSchedulerOutput(
425
- scheduled_new_reqs=[],
426
- scheduled_cached_reqs=CachedRequestData(
270
+ """Test schedule sends SCHEDULE command to all workers and combines output."""
271
+ with patch(
272
+ 'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
273
+ ):
274
+ with patch('multiprocessing.get_context'):
275
+ scheduler = DPScheduler(
276
+ vllm_config=mock_vllm_config,
277
+ kv_cache_config=mock_kv_cache_config,
278
+ structured_output_manager=mock_structured_output_manager,
279
+ block_size=16,
280
+ )
281
+
282
+ # Mock the queues with tuple keys
283
+ scheduler.input_queues = [MagicMock(), MagicMock()]
284
+
285
+ # Create mock scheduler outputs
286
+ mock_output_0 = MagicMock(spec=SchedulerOutput)
287
+ mock_output_0.scheduled_new_reqs = []
288
+ mock_output_0.num_scheduled_tokens = {"req1": 10}
289
+ mock_output_0.total_num_scheduled_tokens = 10
290
+ mock_output_0.finished_req_ids = set()
291
+ mock_output_0.scheduled_cached_reqs = CachedRequestData(
427
292
  req_ids=[],
428
293
  resumed_req_ids=[],
429
294
  new_token_ids=[],
@@ -431,40 +296,17 @@ class TestDPScheduler:
431
296
  new_block_ids=[],
432
297
  num_computed_tokens=[],
433
298
  num_output_tokens=[],
434
- ),
435
- num_scheduled_tokens={},
436
- total_num_scheduled_tokens=0,
437
- scheduled_spec_decode_tokens={},
438
- scheduled_encoder_inputs={},
439
- num_common_prefix_blocks=[],
440
- finished_req_ids=set(),
441
- free_encoder_mm_hashes=set(),
442
- )
443
-
444
- result = scheduler.get_grammar_bitmask(dp_output)
445
- assert result is None
446
-
447
- def test_update_from_output_routes_to_schedulers(
448
- self, mock_vllm_config, mock_kv_cache_config,
449
- mock_structured_output_manager):
450
- """Test update_from_output splits output and updates each scheduler."""
451
- mock_scheduler_cls = MagicMock(return_value=MagicMock())
452
- with patch.object(mock_vllm_config.scheduler_config,
453
- '_original_scheduler_cls', mock_scheduler_cls):
454
- scheduler = DPScheduler(
455
- vllm_config=mock_vllm_config,
456
- kv_cache_config=mock_kv_cache_config,
457
- structured_output_manager=mock_structured_output_manager,
458
- block_size=16,
459
- )
460
-
461
- # Setup assigned ranks
462
- scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}
463
-
464
- # Create DPSchedulerOutput
465
- dp_output = DPSchedulerOutput(
466
- scheduled_new_reqs=[],
467
- scheduled_cached_reqs=CachedRequestData(
299
+ )
300
+ mock_output_0.scheduled_spec_decode_tokens = {}
301
+ mock_output_0.scheduled_encoder_inputs = {}
302
+ mock_output_0.num_common_prefix_blocks = []
303
+
304
+ mock_output_1 = MagicMock(spec=SchedulerOutput)
305
+ mock_output_1.scheduled_new_reqs = []
306
+ mock_output_1.num_scheduled_tokens = {"req2": 20}
307
+ mock_output_1.total_num_scheduled_tokens = 20
308
+ mock_output_1.finished_req_ids = set()
309
+ mock_output_1.scheduled_cached_reqs = CachedRequestData(
468
310
  req_ids=[],
469
311
  resumed_req_ids=[],
470
312
  new_token_ids=[],
@@ -472,397 +314,437 @@ class TestDPScheduler:
472
314
  new_block_ids=[],
473
315
  num_computed_tokens=[],
474
316
  num_output_tokens=[],
475
- ),
476
- num_scheduled_tokens={
477
- "req1": 10,
478
- "req2": 20,
479
- "req3": 15
480
- },
481
- total_num_scheduled_tokens=45,
482
- scheduled_spec_decode_tokens={},
483
- scheduled_encoder_inputs={},
484
- num_common_prefix_blocks=[],
485
- finished_req_ids={"req3"}, # req3 finished
486
- free_encoder_mm_hashes=set(),
487
- assigned_dp_rank={
488
- "req1": 0,
489
- "req2": 1,
490
- "req3": 0
491
- },
492
- )
493
-
494
- # Create mock model runner output
495
- model_output = ModelRunnerOutput(
496
- req_ids=["req1", "req2", "req3"],
497
- req_id_to_index={
498
- "req1": 0,
499
- "req2": 1,
500
- "req3": 2
501
- },
502
- sampled_token_ids=torch.tensor([100, 200, 300]),
503
- logprobs=None,
504
- prompt_logprobs_dict={},
505
- pooler_output=None,
506
- num_nans_in_logits=0,
507
- kv_connector_output=None,
508
- )
509
-
510
- # Mock rank scheduler outputs (cached from schedule call)
511
- rank_output_0 = MagicMock()
512
- rank_output_1 = MagicMock()
513
- scheduler.cached_schedulers_output.append(
514
- [rank_output_0, rank_output_1])
515
-
516
- # Mock scheduler update_from_output
517
- engine_output_0 = EngineCoreOutputs()
518
- engine_output_0.engine_index = 0
519
- engine_output_0.outputs = []
520
- engine_output_0.finished_requests = {"req3"}
521
-
522
- engine_output_1 = EngineCoreOutputs()
523
- engine_output_1.engine_index = 0
524
- engine_output_1.outputs = []
525
- engine_output_1.finished_requests = set()
526
-
527
- scheduler.schedulers[0].update_from_output = MagicMock(
528
- return_value={0: engine_output_0})
529
- scheduler.schedulers[1].update_from_output = MagicMock(
530
- return_value={0: engine_output_1})
531
-
532
- # Mock make_stats
533
- scheduler.make_stats = MagicMock(return_value=None)
534
-
535
- _ = scheduler.update_from_output(dp_output, model_output)
536
-
537
- # Verify schedulers were updated
538
- assert scheduler.schedulers[0].update_from_output.called
539
- assert scheduler.schedulers[1].update_from_output.called
540
-
541
- # Verify finished request was cleaned up
542
- assert "req3" not in scheduler.assigned_dp_rank
543
- assert "req1" in scheduler.assigned_dp_rank
544
- assert "req2" in scheduler.assigned_dp_rank
545
-
546
- def test_split_model_output_by_rank(self, mock_vllm_config,
547
- mock_kv_cache_config,
548
- mock_structured_output_manager):
549
- """Test _split_model_output_by_rank distributes output correctly."""
550
- mock_scheduler_cls = MagicMock(return_value=MagicMock())
551
- with patch.object(mock_vllm_config.scheduler_config,
552
- '_original_scheduler_cls', mock_scheduler_cls):
553
- scheduler = DPScheduler(
554
- vllm_config=mock_vllm_config,
555
- kv_cache_config=mock_kv_cache_config,
556
- structured_output_manager=mock_structured_output_manager,
557
- block_size=16,
558
- )
559
-
560
- # Setup assigned ranks
561
- scheduler.assigned_dp_rank = {
562
- "req1": 0,
563
- "req2": 1,
564
- "req3": 0,
565
- "req4": 1
566
- }
567
-
568
- # Create global model output
569
- global_output = ModelRunnerOutput(
570
- req_ids=["req1", "req2", "req3", "req4"],
571
- req_id_to_index={
572
- "req1": 0,
573
- "req2": 1,
574
- "req3": 2,
575
- "req4": 3
576
- },
577
- sampled_token_ids=torch.tensor([100, 200, 300, 400]),
578
- logprobs=None,
579
- prompt_logprobs_dict={},
580
- pooler_output=None,
581
- num_nans_in_logits=0,
582
- kv_connector_output=None,
583
- )
584
-
585
- rank_outputs = scheduler._split_model_output_by_rank(global_output)
586
-
587
- # Verify split outputs
588
- assert len(rank_outputs) == 2
589
- assert rank_outputs[0].req_ids == ["req1", "req3"]
590
- assert rank_outputs[1].req_ids == ["req2", "req4"]
591
-
592
- def test_cleanup_finished_requests(self, mock_vllm_config,
593
- mock_kv_cache_config,
594
- mock_structured_output_manager):
595
- """Test _cleanup_finished_requests removes finished requests."""
596
- mock_scheduler_cls = MagicMock(return_value=MagicMock())
597
- with patch.object(mock_vllm_config.scheduler_config,
598
- '_original_scheduler_cls', mock_scheduler_cls):
599
- scheduler = DPScheduler(
600
- vllm_config=mock_vllm_config,
601
- kv_cache_config=mock_kv_cache_config,
602
- structured_output_manager=mock_structured_output_manager,
603
- block_size=16,
604
- )
605
-
606
- # Setup assigned ranks
607
- scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}
608
-
609
- # Clean up finished requests
610
- scheduler._cleanup_finished_requests({"req1", "req3"})
611
-
612
- # Verify cleanup
613
- assert "req1" not in scheduler.assigned_dp_rank
614
- assert "req3" not in scheduler.assigned_dp_rank
615
- assert "req2" in scheduler.assigned_dp_rank
616
-
617
- def test_finish_requests_single_and_multiple(
618
- self, mock_vllm_config, mock_kv_cache_config,
619
- mock_structured_output_manager):
620
- """Test finish_requests handles single string and list."""
621
- scheduler = self._create_dp_scheduler_with_mocks(
622
- mock_vllm_config, mock_kv_cache_config,
623
- mock_structured_output_manager)
624
-
625
- # Setup assigned ranks
626
- scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}
627
-
628
- # Mock scheduler finish_requests
629
- scheduler.schedulers[0].finish_requests = MagicMock()
630
- scheduler.schedulers[1].finish_requests = MagicMock()
631
-
632
- # Test with single string
633
- scheduler.finish_requests("req1", finished_status="completed")
634
- scheduler.schedulers[0].finish_requests.assert_called_with(["req1"],
635
- "completed")
636
-
637
- # Test with list
638
- scheduler.schedulers[0].finish_requests.reset_mock()
639
- scheduler.schedulers[1].finish_requests.reset_mock()
640
-
641
- scheduler.finish_requests(["req1", "req2"],
642
- finished_status="completed")
643
- scheduler.schedulers[0].finish_requests.assert_called_once_with(
644
- ["req1"], "completed")
645
- scheduler.schedulers[1].finish_requests.assert_called_once_with(
646
- ["req2"], "completed")
317
+ )
318
+ mock_output_1.scheduled_spec_decode_tokens = {}
319
+ mock_output_1.scheduled_encoder_inputs = {}
320
+ mock_output_1.num_common_prefix_blocks = []
321
+
322
+ # Setup mock queue responses with tuple keys - need to mock .get()
323
+ mock_queue_0 = MagicMock()
324
+ mock_queue_0.get.return_value = mock_output_0
325
+ mock_queue_1 = MagicMock()
326
+ mock_queue_1.get.return_value = mock_output_1
327
+
328
+ scheduler.output_queues = {
329
+ (0, "schedule"): mock_queue_0,
330
+ (1, "schedule"): mock_queue_1,
331
+ }
332
+
333
+ # Setup assigned ranks
334
+ scheduler.assigned_dp_rank = {"req1": 0, "req2": 1}
335
+
336
+ output = scheduler.schedule()
337
+
338
+ # Verify SCHEDULE commands were sent
339
+ scheduler.input_queues[0].put.assert_called_with(
340
+ (SchedulerCommand.SCHEDULE, None))
341
+ scheduler.input_queues[1].put.assert_called_with(
342
+ (SchedulerCommand.SCHEDULE, None))
343
+
344
+ # Verify combined output
345
+ assert isinstance(output, DPSchedulerOutput)
346
+ assert output.total_num_scheduled_tokens == 30
347
+ assert "req1" in output.num_scheduled_tokens
348
+ assert "req2" in output.num_scheduled_tokens
349
+ assert output.assigned_dp_rank == {"req1": 0, "req2": 1}
647
350
 
648
- def test_get_num_unfinished_requests(self, mock_vllm_config,
351
+ def test_combine_cached_request_data(self, mock_vllm_config,
649
352
  mock_kv_cache_config,
650
353
  mock_structured_output_manager):
651
- """Test get_num_unfinished_requests aggregates across ranks."""
652
- scheduler = self._create_dp_scheduler_with_mocks(
653
- mock_vllm_config, mock_kv_cache_config,
654
- mock_structured_output_manager)
655
-
656
- scheduler.schedulers[0].get_num_unfinished_requests = MagicMock(
657
- return_value=5)
658
- scheduler.schedulers[1].get_num_unfinished_requests = MagicMock(
659
- return_value=3)
354
+ """Test _combine_cached_request_data combines data from all ranks."""
355
+ with patch(
356
+ 'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
357
+ ):
358
+ with patch('multiprocessing.get_context'):
359
+ scheduler = DPScheduler(
360
+ vllm_config=mock_vllm_config,
361
+ kv_cache_config=mock_kv_cache_config,
362
+ structured_output_manager=mock_structured_output_manager,
363
+ block_size=16,
364
+ )
365
+
366
+ # Create mock rank outputs
367
+ output_0 = MagicMock(spec=SchedulerOutput)
368
+ output_0.scheduled_cached_reqs = CachedRequestData(
369
+ req_ids=["req1"],
370
+ resumed_req_ids=["req1"],
371
+ new_token_ids=[[1, 2, 3]],
372
+ all_token_ids=[[1, 2, 3, 4, 5]],
373
+ new_block_ids=[[10, 11]],
374
+ num_computed_tokens=[5],
375
+ num_output_tokens=[3],
376
+ )
377
+
378
+ output_1 = MagicMock(spec=SchedulerOutput)
379
+ output_1.scheduled_cached_reqs = CachedRequestData(
380
+ req_ids=["req2"],
381
+ resumed_req_ids=[],
382
+ new_token_ids=[[6, 7]],
383
+ all_token_ids=[[6, 7, 8, 9]],
384
+ new_block_ids=[[20, 21]],
385
+ num_computed_tokens=[4],
386
+ num_output_tokens=[2],
387
+ )
388
+
389
+ combined = scheduler._combine_cached_request_data(
390
+ [output_0, output_1])
391
+
392
+ # Verify combined data
393
+ assert combined.req_ids == ["req1", "req2"]
394
+ assert combined.resumed_req_ids == ["req1"]
395
+ assert combined.new_token_ids == [[1, 2, 3], [6, 7]]
396
+ assert combined.num_computed_tokens == [5, 4]
397
+ assert combined.num_output_tokens == [3, 2]
398
+
399
+ def test_finish_requests_routes_to_workers(self, mock_vllm_config,
400
+ mock_kv_cache_config,
401
+ mock_structured_output_manager):
402
+ """Test finish_requests sends FINISH_REQUESTS command to appropriate workers."""
403
+ with patch(
404
+ 'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
405
+ ):
406
+ with patch('multiprocessing.get_context'):
407
+ scheduler = DPScheduler(
408
+ vllm_config=mock_vllm_config,
409
+ kv_cache_config=mock_kv_cache_config,
410
+ structured_output_manager=mock_structured_output_manager,
411
+ block_size=16,
412
+ )
413
+
414
+ scheduler.input_queues = [MagicMock(), MagicMock()]
415
+ scheduler.output_queues = {
416
+ (0, "finish_requests"): MagicMock(),
417
+ (1, "finish_requests"): MagicMock(),
418
+ }
419
+
420
+ scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}
421
+
422
+ # Test with list of requests
423
+ scheduler.finish_requests(["req1", "req2"],
424
+ finished_status="completed")
425
+
426
+ # Verify FINISH_REQUESTS commands were sent to correct ranks
427
+ scheduler.input_queues[0].put.assert_called()
428
+ scheduler.input_queues[1].put.assert_called()
660
429
 
661
- total = scheduler.get_num_unfinished_requests()
662
- assert total == 8
430
+ def test_get_num_unfinished_requests(self, mock_vllm_config,
431
+ mock_kv_cache_config,
432
+ mock_structured_output_manager):
433
+ """Test get_num_unfinished_requests queries all workers."""
434
+ with patch(
435
+ 'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
436
+ ):
437
+ with patch('multiprocessing.get_context'):
438
+ scheduler = DPScheduler(
439
+ vllm_config=mock_vllm_config,
440
+ kv_cache_config=mock_kv_cache_config,
441
+ structured_output_manager=mock_structured_output_manager,
442
+ block_size=16,
443
+ )
444
+
445
+ scheduler.input_queues = [MagicMock(), MagicMock()]
446
+
447
+ # Create proper mocks for queue.get() calls
448
+ mock_queue_0 = MagicMock()
449
+ mock_queue_0.get.return_value = 5
450
+ mock_queue_1 = MagicMock()
451
+ mock_queue_1.get.return_value = 3
452
+
453
+ scheduler.output_queues = {
454
+ (0, "get_num_unfinished_requests"): mock_queue_0,
455
+ (1, "get_num_unfinished_requests"): mock_queue_1,
456
+ }
457
+
458
+ total = scheduler.get_num_unfinished_requests()
459
+
460
+ # Verify commands were sent
461
+ scheduler.input_queues[0].put.assert_called_with(
462
+ (SchedulerCommand.GET_NUM_UNFINISHED_REQUESTS, None))
463
+ scheduler.input_queues[1].put.assert_called_with(
464
+ (SchedulerCommand.GET_NUM_UNFINISHED_REQUESTS, None))
465
+
466
+ assert total == 8
663
467
 
664
468
  def test_has_finished_requests(self, mock_vllm_config,
665
469
  mock_kv_cache_config,
666
470
  mock_structured_output_manager):
667
- """Test has_finished_requests checks all ranks."""
668
- mock_scheduler_cls = MagicMock(return_value=MagicMock())
669
- with patch.object(mock_vllm_config.scheduler_config,
670
- '_original_scheduler_cls', mock_scheduler_cls):
671
- scheduler = DPScheduler(
672
- vllm_config=mock_vllm_config,
673
- kv_cache_config=mock_kv_cache_config,
674
- structured_output_manager=mock_structured_output_manager,
675
- block_size=16,
676
- )
677
-
678
- # Test when one rank has finished requests
679
- scheduler.schedulers[0].has_finished_requests = MagicMock(
680
- return_value=False)
681
- scheduler.schedulers[1].has_finished_requests = MagicMock(
682
- return_value=True)
683
-
684
- assert scheduler.has_finished_requests() is True
685
-
686
- # Test when no rank has finished requests
687
- scheduler.schedulers[1].has_finished_requests = MagicMock(
688
- return_value=False)
689
- assert scheduler.has_finished_requests() is False
471
+ """Test has_finished_requests checks all workers."""
472
+ with patch(
473
+ 'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
474
+ ):
475
+ with patch('multiprocessing.get_context'):
476
+ scheduler = DPScheduler(
477
+ vllm_config=mock_vllm_config,
478
+ kv_cache_config=mock_kv_cache_config,
479
+ structured_output_manager=mock_structured_output_manager,
480
+ block_size=16,
481
+ )
482
+
483
+ scheduler.input_queues = [MagicMock(), MagicMock()]
484
+
485
+ # Create proper mocks for queue.get() calls
486
+ mock_queue_0 = MagicMock()
487
+ mock_queue_0.get.return_value = False
488
+ mock_queue_1 = MagicMock()
489
+ mock_queue_1.get.return_value = True
490
+
491
+ scheduler.output_queues = {
492
+ (0, "has_finished_requests"): mock_queue_0,
493
+ (1, "has_finished_requests"): mock_queue_1,
494
+ }
495
+
496
+ result = scheduler.has_finished_requests()
497
+
498
+ assert result is True
499
+
500
+ # Verify commands were sent
501
+ scheduler.input_queues[0].put.assert_called_with(
502
+ (SchedulerCommand.HAS_FINISHED_REQUESTS, None))
503
+ scheduler.input_queues[1].put.assert_called_with(
504
+ (SchedulerCommand.HAS_FINISHED_REQUESTS, None))
690
505
 
691
506
  def test_get_request_counts(self, mock_vllm_config, mock_kv_cache_config,
692
507
  mock_structured_output_manager):
693
- """Test get_request_counts aggregates across ranks."""
694
- scheduler = self._create_dp_scheduler_with_mocks(
695
- mock_vllm_config, mock_kv_cache_config,
696
- mock_structured_output_manager)
697
-
698
- # Mock running and waiting queues
699
- scheduler.schedulers[0].running = [MagicMock(),
700
- MagicMock()] # 2 running
701
- scheduler.schedulers[0].waiting = [MagicMock()] # 1 waiting
702
- scheduler.schedulers[1].running = [MagicMock()] # 1 running
703
- scheduler.schedulers[1].waiting = [
704
- MagicMock(), MagicMock(), MagicMock()
705
- ] # 3 waiting
706
-
707
- running, waiting = scheduler.get_request_counts()
708
-
709
- assert running == 3 # 2 + 1
710
- assert waiting == 4 # 1 + 3
508
+ """Test get_request_counts queries all workers."""
509
+ with patch(
510
+ 'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
511
+ ):
512
+ with patch('multiprocessing.get_context'):
513
+ scheduler = DPScheduler(
514
+ vllm_config=mock_vllm_config,
515
+ kv_cache_config=mock_kv_cache_config,
516
+ structured_output_manager=mock_structured_output_manager,
517
+ block_size=16,
518
+ )
519
+
520
+ scheduler.input_queues = [MagicMock(), MagicMock()]
521
+
522
+ # Create proper mocks for queue.get() calls
523
+ mock_queue_0 = MagicMock()
524
+ mock_queue_0.get.return_value = (2, 1)
525
+ mock_queue_1 = MagicMock()
526
+ mock_queue_1.get.return_value = (1, 3)
527
+
528
+ scheduler.output_queues = {
529
+ (0, "get_request_counts"): mock_queue_0,
530
+ (1, "get_request_counts"): mock_queue_1,
531
+ }
532
+
533
+ running, waiting = scheduler.get_request_counts()
534
+
535
+ # Verify commands were sent
536
+ scheduler.input_queues[0].put.assert_called_with(
537
+ (SchedulerCommand.GET_REQUEST_COUNTS, None))
538
+ scheduler.input_queues[1].put.assert_called_with(
539
+ (SchedulerCommand.GET_REQUEST_COUNTS, None))
540
+
541
+ assert running == 3
542
+ assert waiting == 4
711
543
 
712
544
  def test_reset_prefix_cache(self, mock_vllm_config, mock_kv_cache_config,
713
545
  mock_structured_output_manager):
714
- """Test reset_prefix_cache resets all ranks."""
715
- scheduler = self._create_dp_scheduler_with_mocks(
716
- mock_vllm_config, mock_kv_cache_config,
717
- mock_structured_output_manager)
718
-
719
- scheduler.schedulers[0].reset_prefix_cache = MagicMock(
720
- return_value=True)
721
- scheduler.schedulers[1].reset_prefix_cache = MagicMock(
722
- return_value=True)
723
-
724
- result = scheduler.reset_prefix_cache()
725
-
726
- assert result is True
727
- scheduler.schedulers[0].reset_prefix_cache.assert_called_once()
728
- scheduler.schedulers[1].reset_prefix_cache.assert_called_once()
729
-
730
- def test_make_stats_with_logging_enabled(self, mock_vllm_config,
731
- mock_kv_cache_config,
732
- mock_structured_output_manager):
733
- """Test make_stats aggregates stats from all ranks."""
734
- scheduler = self._create_dp_scheduler_with_mocks(
735
- mock_vllm_config,
736
- mock_kv_cache_config,
737
- mock_structured_output_manager,
738
- log_stats=True)
739
-
740
- # Create mock stats for each rank
741
- stats_0 = SchedulerStats(
742
- num_running_reqs=3,
743
- num_waiting_reqs=2,
744
- kv_cache_usage=0.5,
745
- prefix_cache_stats=PrefixCacheStats(reset=False,
746
- requests=10,
747
- queries=8,
748
- hits=5),
749
- connector_prefix_cache_stats=PrefixCacheStats(reset=False,
750
- requests=5,
751
- queries=4,
752
- hits=2),
753
- spec_decoding_stats=None,
754
- kv_connector_stats=None,
755
- )
756
-
757
- stats_1 = SchedulerStats(
758
- num_running_reqs=4,
759
- num_waiting_reqs=1,
760
- kv_cache_usage=0.7,
761
- prefix_cache_stats=PrefixCacheStats(reset=False,
762
- requests=15,
763
- queries=12,
764
- hits=8),
765
- connector_prefix_cache_stats=PrefixCacheStats(reset=False,
766
- requests=6,
767
- queries=5,
768
- hits=3),
769
- spec_decoding_stats=None,
770
- kv_connector_stats=None,
771
- )
772
-
773
- scheduler.schedulers[0].make_stats = MagicMock(return_value=stats_0)
774
- scheduler.schedulers[1].make_stats = MagicMock(return_value=stats_1)
775
-
776
- combined_stats = scheduler.make_stats()
777
-
778
- # Verify aggregated stats
779
- assert combined_stats.num_running_reqs == 7 # 3 + 4
780
- assert combined_stats.num_waiting_reqs == 3 # 2 + 1
781
- assert combined_stats.kv_cache_usage == 0.6 # (0.5 + 0.7) / 2
782
-
783
- # Verify prefix cache stats
784
- assert combined_stats.prefix_cache_stats.requests == 25 # 10 + 15
785
- assert combined_stats.prefix_cache_stats.queries == 20 # 8 + 12
786
- assert combined_stats.prefix_cache_stats.hits == 13 # 5 + 8
787
-
788
- # Verify connector prefix cache stats
789
- assert combined_stats.connector_prefix_cache_stats.requests == 11 # 5 + 6
790
- assert combined_stats.connector_prefix_cache_stats.queries == 9 # 4 + 5
791
- assert combined_stats.connector_prefix_cache_stats.hits == 5 # 2 + 3
792
-
793
- def test_make_stats_with_logging_disabled(self, mock_vllm_config,
794
- mock_kv_cache_config,
795
- mock_structured_output_manager):
796
- """Test make_stats returns None when logging is disabled."""
797
- mock_scheduler_cls = MagicMock(return_value=MagicMock())
798
- with patch.object(mock_vllm_config.scheduler_config,
799
- '_original_scheduler_cls', mock_scheduler_cls):
800
- scheduler = DPScheduler(
801
- vllm_config=mock_vllm_config,
802
- kv_cache_config=mock_kv_cache_config,
803
- structured_output_manager=mock_structured_output_manager,
804
- block_size=16,
805
- log_stats=False,
806
- )
807
-
808
- stats = scheduler.make_stats()
809
- assert stats is None
546
+ """Test reset_prefix_cache sends command to all workers."""
547
+ with patch(
548
+ 'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
549
+ ):
550
+ with patch('multiprocessing.get_context'):
551
+ scheduler = DPScheduler(
552
+ vllm_config=mock_vllm_config,
553
+ kv_cache_config=mock_kv_cache_config,
554
+ structured_output_manager=mock_structured_output_manager,
555
+ block_size=16,
556
+ )
557
+
558
+ scheduler.input_queues = [MagicMock(), MagicMock()]
559
+
560
+ # Create proper mocks for queue.get() calls
561
+ mock_queue_0 = MagicMock()
562
+ mock_queue_0.get.return_value = True
563
+ mock_queue_1 = MagicMock()
564
+ mock_queue_1.get.return_value = True
565
+
566
+ scheduler.output_queues = {
567
+ (0, "reset_prefix_cache"): mock_queue_0,
568
+ (1, "reset_prefix_cache"): mock_queue_1,
569
+ }
570
+
571
+ result = scheduler.reset_prefix_cache()
572
+
573
+ # Verify commands were sent
574
+ scheduler.input_queues[0].put.assert_called_with(
575
+ (SchedulerCommand.RESET_PREFIX_CACHE, None))
576
+ scheduler.input_queues[1].put.assert_called_with(
577
+ (SchedulerCommand.RESET_PREFIX_CACHE, None))
578
+
579
+ assert result is True
580
+
581
+ def test_make_stats_aggregates_from_workers(
582
+ self, mock_vllm_config, mock_kv_cache_config,
583
+ mock_structured_output_manager):
584
+ """Test make_stats aggregates statistics from all workers."""
585
+ with patch(
586
+ 'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
587
+ ):
588
+ with patch('multiprocessing.get_context'):
589
+ scheduler = DPScheduler(
590
+ vllm_config=mock_vllm_config,
591
+ kv_cache_config=mock_kv_cache_config,
592
+ structured_output_manager=mock_structured_output_manager,
593
+ block_size=16,
594
+ log_stats=True,
595
+ )
596
+
597
+ scheduler.input_queues = [MagicMock(), MagicMock()]
598
+
599
+ # Create mock stats
600
+ stats_0 = SchedulerStats(
601
+ num_running_reqs=3,
602
+ num_waiting_reqs=2,
603
+ kv_cache_usage=0.5,
604
+ prefix_cache_stats=PrefixCacheStats(reset=False,
605
+ requests=10,
606
+ queries=8,
607
+ hits=5),
608
+ connector_prefix_cache_stats=PrefixCacheStats(reset=False,
609
+ requests=5,
610
+ queries=4,
611
+ hits=2),
612
+ spec_decoding_stats=None,
613
+ kv_connector_stats=None,
614
+ )
615
+
616
+ stats_1 = SchedulerStats(
617
+ num_running_reqs=4,
618
+ num_waiting_reqs=1,
619
+ kv_cache_usage=0.7,
620
+ prefix_cache_stats=PrefixCacheStats(reset=False,
621
+ requests=15,
622
+ queries=12,
623
+ hits=8),
624
+ connector_prefix_cache_stats=PrefixCacheStats(reset=False,
625
+ requests=6,
626
+ queries=5,
627
+ hits=3),
628
+ spec_decoding_stats=None,
629
+ kv_connector_stats=None,
630
+ )
631
+
632
+ # Create proper mocks for queue.get() calls
633
+ mock_queue_0 = MagicMock()
634
+ mock_queue_0.get.return_value = stats_0
635
+ mock_queue_1 = MagicMock()
636
+ mock_queue_1.get.return_value = stats_1
637
+
638
+ scheduler.output_queues = {
639
+ (0, "make_stats"): mock_queue_0,
640
+ (1, "make_stats"): mock_queue_1,
641
+ }
642
+
643
+ combined_stats = scheduler.make_stats()
644
+
645
+ # Verify commands were sent
646
+ scheduler.input_queues[0].put.assert_called_with(
647
+ (SchedulerCommand.MAKE_STATS, (None, None)))
648
+ scheduler.input_queues[1].put.assert_called_with(
649
+ (SchedulerCommand.MAKE_STATS, (None, None)))
650
+
651
+ assert combined_stats.num_running_reqs == 7
652
+ assert combined_stats.num_waiting_reqs == 3
653
+ assert combined_stats.kv_cache_usage == 0.6
654
+
655
+ def test_make_stats_returns_none_when_disabled(
656
+ self, mock_vllm_config, mock_kv_cache_config,
657
+ mock_structured_output_manager):
658
+ """Test make_stats returns None when logging disabled."""
659
+ with patch(
660
+ 'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
661
+ ):
662
+ with patch('multiprocessing.get_context'):
663
+ scheduler = DPScheduler(
664
+ vllm_config=mock_vllm_config,
665
+ kv_cache_config=mock_kv_cache_config,
666
+ structured_output_manager=mock_structured_output_manager,
667
+ block_size=16,
668
+ log_stats=False,
669
+ )
670
+
671
+ stats = scheduler.make_stats()
672
+ assert stats is None
810
673
 
811
674
  def test_update_draft_token_ids(self, mock_vllm_config,
812
675
  mock_kv_cache_config,
813
676
  mock_structured_output_manager):
814
- """Test update_draft_token_ids routes tokens to correct ranks."""
815
- scheduler = self._create_dp_scheduler_with_mocks(
816
- mock_vllm_config, mock_kv_cache_config,
817
- mock_structured_output_manager)
818
-
819
- # Setup assigned ranks
820
- scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}
821
-
822
- # Create mock draft token IDs
823
- draft_token_ids = MagicMock()
824
- draft_token_ids.req_ids = ["req1", "req2", "req3"]
825
- draft_token_ids.draft_token_ids = [
826
- [101, 102, 103],
827
- [201, 202],
828
- [301, 302, 303, 304],
829
- ]
830
-
831
- # Mock scheduler update_draft_token_ids
832
- scheduler.schedulers[0].update_draft_token_ids = MagicMock()
833
- scheduler.schedulers[1].update_draft_token_ids = MagicMock()
834
-
835
- scheduler.update_draft_token_ids(draft_token_ids)
836
-
837
- # Verify each scheduler received correct tokens
838
- assert scheduler.schedulers[0].update_draft_token_ids.called
839
- assert scheduler.schedulers[1].update_draft_token_ids.called
840
-
841
- # Check rank 0 got req1 and req3
842
- call_args_0 = scheduler.schedulers[0].update_draft_token_ids.call_args[
843
- 0][0]
844
- assert "req1" in call_args_0.req_ids
845
- assert "req3" in call_args_0.req_ids
846
-
847
- # Check rank 1 got req2
848
- call_args_1 = scheduler.schedulers[1].update_draft_token_ids.call_args[
849
- 0][0]
850
- assert "req2" in call_args_1.req_ids
677
+ """Test update_draft_token_ids routes to correct workers."""
678
+ with patch(
679
+ 'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
680
+ ):
681
+ with patch('multiprocessing.get_context'):
682
+ scheduler = DPScheduler(
683
+ vllm_config=mock_vllm_config,
684
+ kv_cache_config=mock_kv_cache_config,
685
+ structured_output_manager=mock_structured_output_manager,
686
+ block_size=16,
687
+ )
688
+
689
+ scheduler.input_queues = [MagicMock(), MagicMock()]
690
+ scheduler.output_queues = {
691
+ (0, "update_draft_token_ids"): MagicMock(),
692
+ (1, "update_draft_token_ids"): MagicMock(),
693
+ }
694
+
695
+ scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}
696
+
697
+ draft_token_ids = MagicMock()
698
+ draft_token_ids.req_ids = ["req1", "req2", "req3"]
699
+ draft_token_ids.draft_token_ids = [
700
+ [101, 102, 103],
701
+ [201, 202],
702
+ [301, 302, 303, 304],
703
+ ]
704
+
705
+ scheduler.update_draft_token_ids(draft_token_ids)
706
+
707
+ # Verify commands were sent to correct workers
708
+ scheduler.input_queues[0].put.assert_called()
709
+ scheduler.input_queues[1].put.assert_called()
851
710
 
852
711
  def test_shutdown(self, mock_vllm_config, mock_kv_cache_config,
853
712
  mock_structured_output_manager):
854
- """Test shutdown calls shutdown on all schedulers."""
855
- scheduler = self._create_dp_scheduler_with_mocks(
856
- mock_vllm_config, mock_kv_cache_config,
857
- mock_structured_output_manager)
858
-
859
- scheduler.schedulers[0].shutdown = MagicMock()
860
- scheduler.schedulers[1].shutdown = MagicMock()
861
-
862
- scheduler.shutdown()
863
-
864
- scheduler.schedulers[0].shutdown.assert_called_once()
865
- scheduler.schedulers[1].shutdown.assert_called_once()
713
+ """Test shutdown sends SHUTDOWN command to all workers."""
714
+ with patch(
715
+ 'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
716
+ ):
717
+ with patch('multiprocessing.get_context'):
718
+ scheduler = DPScheduler(
719
+ vllm_config=mock_vllm_config,
720
+ kv_cache_config=mock_kv_cache_config,
721
+ structured_output_manager=mock_structured_output_manager,
722
+ block_size=16,
723
+ )
724
+
725
+ scheduler.input_queues = [MagicMock(), MagicMock()]
726
+ scheduler.output_queues = {
727
+ (0, "shutdown"): MagicMock(),
728
+ (1, "shutdown"): MagicMock(),
729
+ }
730
+
731
+ mock_process_0 = MagicMock()
732
+ mock_process_1 = MagicMock()
733
+ mock_process_0.is_alive = MagicMock(return_value=False)
734
+ mock_process_1.is_alive = MagicMock(return_value=False)
735
+ scheduler.processes = [mock_process_0, mock_process_1]
736
+
737
+ scheduler.shutdown()
738
+
739
+ # Verify SHUTDOWN commands were sent
740
+ scheduler.input_queues[0].put.assert_called_with(
741
+ (SchedulerCommand.SHUTDOWN, None))
742
+ scheduler.input_queues[1].put.assert_called_with(
743
+ (SchedulerCommand.SHUTDOWN, None))
744
+
745
+ # Verify processes were joined
746
+ mock_process_0.join.assert_called()
747
+ mock_process_1.join.assert_called()
866
748
 
867
749
 
868
750
  class TestUpdateVllmConfigForDPScheduler: