tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
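The diff that follows covers a single file from the list above: entry 118, tpu_inference/core/sched/dp_scheduler.py, a new module of 814 added lines.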
@@ -0,0 +1,814 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import multiprocessing.reduction
+from collections import defaultdict, deque
+from dataclasses import dataclass
+from enum import Enum
+from multiprocessing import Process, Queue
+from typing import Any, Dict, List, Optional, Tuple
+
+import cloudpickle
+import torch
+from vllm.config import VllmConfig
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
+from vllm.v1.core.sched.async_scheduler import AsyncScheduler
+from vllm.v1.core.sched.interface import SchedulerInterface
+from vllm.v1.core.sched.output import (CachedRequestData, GrammarOutput,
+                                       SchedulerOutput)
+from vllm.v1.core.sched.scheduler import Scheduler
+from vllm.v1.engine import EngineCoreOutputs
+from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
+from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.request import Request
+from vllm.v1.structured_output import StructuredOutputManager
+
+from tpu_inference.logger import init_logger
+from tpu_inference.utils import time_function
+
+logger = init_logger(__name__)
+
+
+class SchedulerCommand(Enum):
+    """Enum for scheduler worker process commands."""
+    ADD_REQUEST = "add_request"
+    SCHEDULE = "schedule"
+    FINISH_REQUESTS = "finish_requests"
+    UPDATE_DRAFT_TOKEN_IDS = "update_draft_token_ids"
+    UPDATE_FROM_OUTPUT = "update_from_output"
+    GET_GRAMMAR_BITMASK = "get_grammar_bitmask"
+    MAKE_STATS = "make_stats"
+    RESET_PREFIX_CACHE = "reset_prefix_cache"
+    GET_NUM_UNFINISHED_REQUESTS = "get_num_unfinished_requests"
+    HAS_FINISHED_REQUESTS = "has_finished_requests"
+    GET_REQUEST_COUNTS = "get_request_counts"
+    GET_TOKEN_COUNT = "get_token_count"
+    GET_COMPUTED_BLOCKS = "get_computed_blocks"
+    SHUTDOWN = "shutdown"
+
+
+class SchedulerWorkerError(Exception):
+    """Exception raised when a scheduler worker process encounters an error."""
+
+    def __init__(self, rank: int, message: str):
+        self.rank = rank
+        self.message = message
+        super().__init__(f"Scheduler worker {rank} error: {message}")
+
+
+# Monkey-patch multiprocessing to use cloudpickle
+# Standard pickle fails to serialize the vLLM Request object.
+_original_dumps = multiprocessing.reduction.ForkingPickler.dumps
+_original_loads = multiprocessing.reduction.ForkingPickler.loads
+
+
+def _cloudpickle_dumps(obj, protocol=None):
+    """Use cloudpickle for serialization."""
+    return cloudpickle.dumps(obj, protocol=protocol)
+
+
+def _cloudpickle_loads(data):
+    """Use cloudpickle for deserialization."""
+    return cloudpickle.loads(data)
+
+
+def _enable_cloudpickle():
+    """Enable cloudpickle for multiprocessing queues."""
+    multiprocessing.reduction.ForkingPickler.dumps = staticmethod(
+        _cloudpickle_dumps)
+    multiprocessing.reduction.ForkingPickler.loads = staticmethod(
+        _cloudpickle_loads)
+
+
+def _disable_cloudpickle():
+    """Restore original pickle for multiprocessing."""
+    multiprocessing.reduction.ForkingPickler.dumps = _original_dumps
+    multiprocessing.reduction.ForkingPickler.loads = _original_loads
+
+
+def _scheduler_worker_process(
+    rank: int,
+    input_queue: Queue,
+    output_queue: Queue,
+    vllm_config: Any,
+    kv_cache_config: Any,
+    structured_output_manager: Any,
+    block_size: int,
+    mm_registry: Any,
+    include_finished_set: bool,
+    log_stats: bool,
+    original_scheduler_cls: type,
+):
+    """Worker process that manages a single scheduler instance."""
+    # Initialize the scheduler in this process
+    scheduler = original_scheduler_cls(
+        vllm_config=vllm_config,
+        kv_cache_config=kv_cache_config,
+        structured_output_manager=structured_output_manager,
+        block_size=block_size,
+        mm_registry=mm_registry,
+        include_finished_set=include_finished_set,
+        log_stats=log_stats,
+    )
+
+    logger.debug(f"Scheduler worker process {rank} started")
+
+    # Process commands from the input queue
+    while True:
+        try:
+            command, data = input_queue.get()
+
+            match command:
+                case SchedulerCommand.ADD_REQUEST:
+                    request = data
+                    scheduler.add_request(request)
+                    output_queue.put(None)  # Signal completion
+
+                case SchedulerCommand.SCHEDULE:
+                    output = scheduler.schedule()
+                    output_queue.put(output)
+
+                case SchedulerCommand.FINISH_REQUESTS:
+                    request_ids, finished_status = data
+                    scheduler.finish_requests(request_ids, finished_status)
+                    output_queue.put(None)  # Signal completion
+
+                case SchedulerCommand.UPDATE_DRAFT_TOKEN_IDS:
+                    draft_token_ids = data
+                    scheduler.update_draft_token_ids(draft_token_ids)
+                    output_queue.put(None)  # Signal completion
+
+                case SchedulerCommand.UPDATE_FROM_OUTPUT:
+                    scheduler_output, model_runner_output = data
+                    result = scheduler.update_from_output(
+                        scheduler_output, model_runner_output)
+                    output_queue.put(result)
+
+                case SchedulerCommand.GET_GRAMMAR_BITMASK:
+                    scheduler_output = data
+                    result = scheduler.get_grammar_bitmask(scheduler_output)
+                    output_queue.put(result)
+
+                case SchedulerCommand.MAKE_STATS:
+                    spec_decoding_stats, kv_connector_stats = data
+                    result = scheduler.make_stats(spec_decoding_stats,
+                                                  kv_connector_stats)
+                    output_queue.put(result)
+
+                case SchedulerCommand.RESET_PREFIX_CACHE:
+                    result = scheduler.reset_prefix_cache()
+                    output_queue.put(result)
+
+                case SchedulerCommand.GET_NUM_UNFINISHED_REQUESTS:
+                    result = scheduler.get_num_unfinished_requests()
+                    output_queue.put(result)
+
+                case SchedulerCommand.HAS_FINISHED_REQUESTS:
+                    result = scheduler.has_finished_requests()
+                    output_queue.put(result)
+
+                case SchedulerCommand.GET_REQUEST_COUNTS:
+                    running = len(scheduler.running)
+                    waiting = len(scheduler.waiting)
+                    output_queue.put((running, waiting))
+
+                case SchedulerCommand.GET_TOKEN_COUNT:
+                    # Calculate total tokens across running and waiting requests
+                    total_tokens = 0
+                    for req in scheduler.running:
+                        total_tokens += len(req.all_token_ids)
+                    for req in scheduler.waiting:
+                        total_tokens += len(req.all_token_ids)
+                    output_queue.put(total_tokens)
+
+                case SchedulerCommand.GET_COMPUTED_BLOCKS:
+                    request = data
+                    blocks, cached_tokens = scheduler.kv_cache_manager.get_computed_blocks(
+                        request)
+                    output_queue.put((blocks, cached_tokens))
+
+                case SchedulerCommand.SHUTDOWN:
+                    scheduler.shutdown()
+                    output_queue.put(None)  # Signal completion
+                    break
+                case _:
+                    error = SchedulerWorkerError(
+                        rank, f"Unknown command: {command}")
+                    output_queue.put(error)
+                    raise error
+
+        except Exception as e:
+            logger.error(f"Error in scheduler worker {rank}: {e}",
+                         exc_info=True)
+            # Put error on output queue
+            error = SchedulerWorkerError(rank, str(e))
+            output_queue.put(error)
+
+
+@dataclass
+class DPSchedulerOutput(SchedulerOutput):
+    """Extended SchedulerOutput that includes DP rank assignments."""
+    assigned_dp_rank: Optional[Dict[str, int]] = None
+
+    def __init__(self, *args, assigned_dp_rank=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.assigned_dp_rank = assigned_dp_rank or {}
+
+
+class DPScheduler(SchedulerInterface):
+    """
+    DPScheduler is used when the DP size is >= 2; otherwise the default vLLM
+    scheduler is used.
+
+    The DPScheduler manages:
+    1. Multiple vLLM Schedulers (one per DP rank)
+    2. Request-to-scheduler assignment
+
+    Each Scheduler manages its own logical KV cache shard and scheduling logic.
+
+    **Load Balancing**
+
+    For new requests:
+    - If there is a prefix cache hit, the request is assigned to the rank with
+      the best hit
+    - Otherwise, the request is assigned to the rank with the fewest total
+      tokens
+
+    Once a DP rank is assigned to a request, it remains fixed for the
+    request's lifetime. A request is freed from its assigned rank when it is
+    completed or preempted.
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        kv_cache_config: KVCacheConfig,
+        structured_output_manager: StructuredOutputManager,
+        block_size: int,
+        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
+        include_finished_set: bool = False,
+        log_stats: bool = False,
+    ) -> None:
+        self.vllm_config = vllm_config
+        self.block_size = block_size
+        self.log_stats = log_stats
+        self.connector = None
+        self.structured_output_manager = structured_output_manager
+
+        # DP state
+        self.dp_size = vllm_config.sharding_config.total_dp_size
+        self.assigned_dp_rank: Dict[str, int] = {}  # req_id -> dp_rank
+        self.cached_schedulers_output = deque()
+        self._create_per_rank_configs(kv_cache_config)
+
+        # The original scheduler class could be Scheduler or AsyncScheduler
+        original_scheduler_cls = vllm_config.scheduler_config._original_scheduler_cls
+
+        # Enable cloudpickle for multiprocessing to handle local functions
+        _enable_cloudpickle()
+
+        # Create worker processes with one input and one output queue each
+        import multiprocessing
+        ctx = multiprocessing.get_context('fork')
+        self.input_queues: List[Queue] = []
+        self.output_queues: List[Queue] = []
+        self.processes: List[Process] = []
+
+        for rank in range(self.dp_size):
+            input_queue = ctx.Queue()
+            output_queue = ctx.Queue()
+
+            self.input_queues.append(input_queue)
+            self.output_queues.append(output_queue)
+
+            process = ctx.Process(
+                target=_scheduler_worker_process,
+                args=(
+                    rank,
+                    input_queue,
+                    output_queue,
+                    self.vllm_config,
+                    self.per_rank_kv_cache_configs[rank],
+                    structured_output_manager,
+                    block_size,
+                    mm_registry,
+                    include_finished_set,
+                    log_stats,
+                    original_scheduler_cls,
+                ),
+            )
+            process.start()
+            self.processes.append(process)
+
+        logger.info(
+            f"DPScheduler (Async = {self.vllm_config.scheduler_config.async_scheduling}) "
+            f"started {self.dp_size} worker processes with cloudpickle. "
+            f"Per-rank limits: max_seqs={self.vllm_config.scheduler_config.max_num_seqs}, "
+            f"max_tokens={self.vllm_config.scheduler_config.max_num_batched_tokens}"
+        )
+
+    def _create_per_rank_configs(self, kv_cache_config: KVCacheConfig) -> None:
+        self.per_rank_kv_cache_configs: List[KVCacheConfig] = []
+        for _ in range(self.dp_size):
+            rank_config = copy.deepcopy(kv_cache_config)
+            rank_config.num_blocks = kv_cache_config.num_blocks // self.dp_size
+            self.per_rank_kv_cache_configs.append(rank_config)
+
+    def _get_result_from_queue(self, queue: Queue) -> Any:
+        result = queue.get()
+        if isinstance(result, SchedulerWorkerError):
+            raise result
+        return result
+
+    def _get_rank_token_counts(self) -> Dict[int, int]:
+        """Calculate total tokens currently assigned to each DP rank."""
+        for rank in range(self.dp_size):
+            self.input_queues[rank].put(
+                (SchedulerCommand.GET_TOKEN_COUNT, None))
+
+        rank_tokens = {}
+        for rank in range(self.dp_size):
+            token_count = self._get_result_from_queue(self.output_queues[rank])
+            rank_tokens[rank] = token_count
+
+        return rank_tokens
+
+    def _find_best_rank_for_request(self, request: Request) -> int:
+        """Find the best DP rank for a new request based on load balancing."""
+        rank_tokens = self._get_rank_token_counts()
+
+        # First, try to find a rank with prefix cache hit
+        for rank in range(self.dp_size):
+            self.input_queues[rank].put(
+                (SchedulerCommand.GET_COMPUTED_BLOCKS, request))
+
+        best_cache_rank = None
+        best_cache_tokens = 0
+        for rank in range(self.dp_size):
+            blocks, cached_tokens = self._get_result_from_queue(
+                self.output_queues[rank])
+            if cached_tokens > best_cache_tokens:
+                best_cache_tokens = cached_tokens
+                best_cache_rank = rank
+        if best_cache_tokens > 0:
+            return best_cache_rank
+
+        # Otherwise, find rank with least tokens
+        selected_rank = min(rank_tokens, key=rank_tokens.get)
+        return selected_rank
+
+    def add_request(self, request: Request) -> None:
+        """
+        Add a new request to the appropriate DP rank scheduler.
+
+        This is the main entry point for new requests. The scheduler will:
+        1. Determine the best DP rank for the request (load balancing + cache hits)
+        2. Assign the request to that rank
+        3. Add the request to the rank's scheduler
+        """
+        assert request.request_id not in self.assigned_dp_rank, (
+            f"Request {request.request_id} already "
+            f"assigned to rank {self.assigned_dp_rank[request.request_id]}")
+        rank = self._find_best_rank_for_request(request)
+        self.assigned_dp_rank[request.request_id] = rank
+
+        self.input_queues[rank].put((SchedulerCommand.ADD_REQUEST, request))
+        self._get_result_from_queue(self.output_queues[rank])
+
+    @time_function
+    def schedule(self) -> DPSchedulerOutput:
+        """
+        Main scheduling method that coordinates all DP rank schedulers.
+
+        Process:
+        1. Send a schedule command to every DP rank scheduler
+        2. Collect each scheduler's output (blocking)
+        3. Combine the outputs from all schedulers
+        4. Return the unified scheduling result
+        """
+        # Run each scheduler independently
+        for rank in range(self.dp_size):
+            self.input_queues[rank].put((SchedulerCommand.SCHEDULE, None))
+
+        # Collect outputs from all workers (blocking)
+        rank_outputs = []
+        for rank in range(self.dp_size):
+            output = self._get_result_from_queue(self.output_queues[rank])
+            rank_outputs.append(output)
+
+        # Cache scheduler outputs to use in `update_from_output`
+        self.cached_schedulers_output.append(rank_outputs)
+
+        # Return combined scheduler outputs
+        combined_output = self._combine_scheduler_outputs(rank_outputs)
+
+        logger.debug(
+            f"DPScheduler scheduled: "
+            f"{combined_output.total_num_scheduled_tokens} total tokens, "
+            f"{len(combined_output.scheduled_new_reqs)} new requests, "
+            f"{len(combined_output.scheduled_cached_reqs.req_ids)} cached requests"
+        )
+
+        return combined_output
+
+    def _combine_scheduler_outputs(
+            self, rank_outputs: List[SchedulerOutput]) -> DPSchedulerOutput:
+        """Combine outputs from all DP rank schedulers into a unified output."""
+
+        # Combine new requests
+        all_new_reqs = []
+        for output in rank_outputs:
+            all_new_reqs.extend(output.scheduled_new_reqs)
+
+        # Combine cached request data
+        combined_cached_data = self._combine_cached_request_data(rank_outputs)
+
+        # Combine token counts and other metrics
+        combined_num_scheduled_tokens = {}
+        combined_spec_decode_tokens = {}
+        combined_encoder_inputs = {}
+        total_scheduled_tokens = 0
+
+        for output in rank_outputs:
+            combined_num_scheduled_tokens.update(output.num_scheduled_tokens)
+            combined_spec_decode_tokens.update(
+                output.scheduled_spec_decode_tokens)
+            combined_encoder_inputs.update(output.scheduled_encoder_inputs)
+            total_scheduled_tokens += output.total_num_scheduled_tokens
+
+        # Combine finished request IDs
+        combined_finished_req_ids = set()
+        for output in rank_outputs:
+            combined_finished_req_ids.update(output.finished_req_ids)
+
+        # Combine other fields (take from first non-empty or use defaults)
+        num_common_prefix_blocks = rank_outputs[
+            0].num_common_prefix_blocks if rank_outputs else []
+
+        # Create DP rank assignment mapping for scheduled requests
+        assigned_dp_rank = {}
+        for req_id in combined_num_scheduled_tokens.keys():
+            assigned_dp_rank[req_id] = self.assigned_dp_rank[req_id]
+
+        return DPSchedulerOutput(
+            scheduled_new_reqs=all_new_reqs,
+            scheduled_cached_reqs=combined_cached_data,
+            num_scheduled_tokens=combined_num_scheduled_tokens,
+            total_num_scheduled_tokens=total_scheduled_tokens,
+            scheduled_spec_decode_tokens=combined_spec_decode_tokens,
+            scheduled_encoder_inputs=combined_encoder_inputs,
+            num_common_prefix_blocks=num_common_prefix_blocks,
+            finished_req_ids=combined_finished_req_ids,
+            free_encoder_mm_hashes=set(),
+            assigned_dp_rank=assigned_dp_rank,
+        )
+
+    def _combine_cached_request_data(
+            self, rank_outputs: List[SchedulerOutput]) -> CachedRequestData:
+        """Combine cached request data from all DP rank schedulers."""
+        combined_req_ids = []
+        combined_resumed_req_ids = []
+        combined_new_token_ids = []
+        combined_all_token_ids = []
+        combined_new_block_ids = []
+        combined_num_computed_tokens = []
+        combined_num_output_tokens = []
+
+        for output in rank_outputs:
+            cached_data = output.scheduled_cached_reqs
+
+            combined_req_ids.extend(cached_data.req_ids)
+            combined_resumed_req_ids.extend(cached_data.resumed_req_ids)
+            combined_new_token_ids.extend(cached_data.new_token_ids)
+            combined_all_token_ids.extend(cached_data.all_token_ids)
+            combined_new_block_ids.extend(cached_data.new_block_ids)
+            combined_num_computed_tokens.extend(
+                cached_data.num_computed_tokens)
+            combined_num_output_tokens.extend(cached_data.num_output_tokens)
+
+        return CachedRequestData(
+            req_ids=combined_req_ids,
+            resumed_req_ids=combined_resumed_req_ids,
+            new_token_ids=combined_new_token_ids,
+            all_token_ids=combined_all_token_ids,
+            new_block_ids=combined_new_block_ids,
+            num_computed_tokens=combined_num_computed_tokens,
+            num_output_tokens=combined_num_output_tokens,
+        )
+
+    def get_grammar_bitmask(
+        self,
+        scheduler_output: DPSchedulerOutput,
+    ) -> GrammarOutput | None:
+        """
+        Generate grammar bitmask for structured output requests across all DP ranks.
+
+        This method calls get_grammar_bitmask on each underlying scheduler and
+        combines their outputs, similar to how other operations are handled.
+        """
+        # Use the most recent cached outputs from the schedule() call
+        if not self.cached_schedulers_output:
+            return None
+
+        rank_scheduler_outputs = self.cached_schedulers_output[
+            -1]  # Get the most recent
+
+        combined_structured_output_request_ids = []
+        combined_bitmasks = []
+
+        # Get grammar bitmask from each DP rank scheduler
+        for rank in range(self.dp_size):
+            self.input_queues[rank].put((SchedulerCommand.GET_GRAMMAR_BITMASK,
+                                         rank_scheduler_outputs[rank]))
+        for rank in range(self.dp_size):
+            grammar_output = self._get_result_from_queue(
+                self.output_queues[rank])
+            if grammar_output is not None:
+                combined_structured_output_request_ids.extend(
+                    grammar_output.structured_output_request_ids)
+                combined_bitmasks.append(grammar_output.grammar_bitmask)
+
+        if not combined_structured_output_request_ids:
+            return None
+
+        # Combine bitmasks - concatenate along the batch dimension
+        if len(combined_bitmasks) == 1:
+            combined_bitmask = combined_bitmasks[0]
+        else:
+            combined_bitmask = torch.cat(combined_bitmasks, dim=0)
+
+        return GrammarOutput(combined_structured_output_request_ids,
+                             combined_bitmask)
+
+    def update_from_output(
+            self, scheduler_output: DPSchedulerOutput,
+            model_runner_output: ModelRunnerOutput
+    ) -> dict[int, EngineCoreOutputs]:
+        """
+        Update all DP rank schedulers based on model runner output.
+
+        We need to route the model runner output to the appropriate scheduler
+        based on which rank each request belongs to.
+        """
+        # Group model runner outputs by DP rank
+        rank_model_outputs = self._split_model_output_by_rank(
+            model_runner_output)
+        rank_scheduler_outputs = self.cached_schedulers_output.popleft()
+        # Update each scheduler with its portion of the output
+        for rank in range(self.dp_size):
+            self.input_queues[rank].put(
+                (SchedulerCommand.UPDATE_FROM_OUTPUT,
+                 (rank_scheduler_outputs[rank], rank_model_outputs[rank])))
+
+        combined_engine_outputs = defaultdict(list)
+        for rank in range(self.dp_size):
+            rank_engine_outputs = self._get_result_from_queue(
+                self.output_queues[rank])
+            for client_idx, engine_output in rank_engine_outputs.items():
+                combined_engine_outputs[client_idx].append(engine_output)
+
+        # Clean up finished requests from DP tracking
+        self._cleanup_finished_requests(scheduler_output.finished_req_ids)
+
+        # Return combined EngineCoreOutputs
+        for client_idx, engine_outputs in combined_engine_outputs.items():
+            combined_output = EngineCoreOutputs()
+            outputs = []
+            finished_requests = set()
+            for engine_output in engine_outputs:
+                outputs.extend(engine_output.outputs)
+                if engine_output.finished_requests:
+                    finished_requests.update(engine_output.finished_requests)
+            combined_output.engine_index = engine_outputs[0].engine_index
+            combined_output.outputs = outputs
+            combined_output.finished_requests = finished_requests
+            combined_output.scheduler_stats = self.make_stats()
+            combined_engine_outputs[client_idx] = combined_output
+
+        return combined_engine_outputs
+
+    def _split_model_output_by_rank(
+            self,
+            global_model_output: ModelRunnerOutput) -> List[ModelRunnerOutput]:
+        """Split the model runner output by DP rank for individual scheduler updates."""
+        outputs = [
+            ModelRunnerOutput(
+                req_ids=[],
+                req_id_to_index=global_model_output.req_id_to_index,
+                sampled_token_ids=global_model_output.sampled_token_ids,
+                logprobs=global_model_output.logprobs,
+                prompt_logprobs_dict=global_model_output.prompt_logprobs_dict,
+                pooler_output=None,
+                num_nans_in_logits=global_model_output.num_nans_in_logits,
+                kv_connector_output=global_model_output.kv_connector_output,
+            ) for _ in range(self.dp_size)
+        ]
+
+        for req_id in global_model_output.req_ids:
+            rank = self.assigned_dp_rank[req_id]
+            outputs[rank].req_ids.append(req_id)
+
+        return outputs
+
+    def _cleanup_finished_requests(self, finished_req_ids: set[str]) -> None:
+        """Remove finished requests from our DP rank assignment tracking."""
+        for req_id in finished_req_ids:
+            if req_id in self.assigned_dp_rank:
+                del self.assigned_dp_rank[req_id]
+
+    def finish_requests(self, request_ids, finished_status) -> None:
+        """Forward request finish signals to the appropriate DP rank schedulers."""
+        if isinstance(request_ids, str):
+            request_ids = [request_ids]
+
+        # Route finish signals to appropriate schedulers
+        rank_request_ids = defaultdict(list)
+        for req_id in request_ids:
+            rank = self.assigned_dp_rank[req_id]
+            rank_request_ids[rank].append(req_id)
+
+        # Forward to each scheduler
+        for rank, req_ids in rank_request_ids.items():
+            self.input_queues[rank].put(
+                (SchedulerCommand.FINISH_REQUESTS, (req_ids, finished_status)))
+            self._get_result_from_queue(self.output_queues[rank])
+
+    def get_num_unfinished_requests(self) -> int:
+        """Get total number of unfinished requests across all DP ranks."""
+        for rank in range(self.dp_size):
+            self.input_queues[rank].put(
+                (SchedulerCommand.GET_NUM_UNFINISHED_REQUESTS, None))
+
+        total = 0
+        for rank in range(self.dp_size):
+            count = self._get_result_from_queue(self.output_queues[rank])
+            total += count
+        return total
+
+    def has_finished_requests(self) -> bool:
+        """Check if any DP rank has finished requests."""
+        for rank in range(self.dp_size):
+            self.input_queues[rank].put(
+                (SchedulerCommand.HAS_FINISHED_REQUESTS, None))
+
+        has_finished_any = False
+        for rank in range(self.dp_size):
+            has_finished_any |= self._get_result_from_queue(
+                self.output_queues[rank])
+        return has_finished_any
+
+    def get_request_counts(self) -> Tuple[int, int]:
+        """Get total (running, waiting) request counts across all DP ranks."""
+        for rank in range(self.dp_size):
+            self.input_queues[rank].put(
+                (SchedulerCommand.GET_REQUEST_COUNTS, None))
+
+        total_running = 0
+        total_waiting = 0
+        for rank in range(self.dp_size):
+            running, waiting = self._get_result_from_queue(
+                self.output_queues[rank])
+            total_running += running
+            total_waiting += waiting
+        return total_running, total_waiting
+
+    def reset_prefix_cache(self) -> bool:
+        """Reset prefix cache for all DP rank schedulers."""
+        for rank in range(self.dp_size):
+            self.input_queues[rank].put(
+                (SchedulerCommand.RESET_PREFIX_CACHE, None))
+
+        all_success = True
+        for rank in range(self.dp_size):
+            success = self._get_result_from_queue(self.output_queues[rank])
+            all_success &= success
+        return all_success
+
+    def make_stats(self,
+                   spec_decoding_stats=None,
+                   kv_connector_stats=None) -> Optional[SchedulerStats]:
+        """Combine stats from all DP rank schedulers."""
+        if not self.log_stats:
+            return None
+
+        # Aggregate stats from all schedulers
+        total_running_reqs = 0
+        total_waiting_reqs = 0
+        total_kv_cache_usage = 0.0
+
+        combined_prefix_cache_stats = PrefixCacheStats()
+        combined_connector_prefix_cache_stats: Optional[
+            PrefixCacheStats] = None
+
+        for rank in range(self.dp_size):
+            self.input_queues[rank].put(
+                (SchedulerCommand.MAKE_STATS, (spec_decoding_stats,
+                                               kv_connector_stats)))
+
+        for rank in range(self.dp_size):
+            rank_stats = self._get_result_from_queue(self.output_queues[rank])
+            if rank_stats is None:
+                continue
+
+            total_running_reqs += rank_stats.num_running_reqs
+            total_waiting_reqs += rank_stats.num_waiting_reqs
+            total_kv_cache_usage += rank_stats.kv_cache_usage
+
+            # Combine prefix cache stats
+            if rank_stats.prefix_cache_stats:
+                combined_prefix_cache_stats.reset = rank_stats.prefix_cache_stats.reset
+                combined_prefix_cache_stats.requests += rank_stats.prefix_cache_stats.requests
+                combined_prefix_cache_stats.queries += rank_stats.prefix_cache_stats.queries
+                combined_prefix_cache_stats.hits += rank_stats.prefix_cache_stats.hits
+
+            # Combine connector prefix cache stats
+            if rank_stats.connector_prefix_cache_stats:
+                if combined_connector_prefix_cache_stats is None:
+                    combined_connector_prefix_cache_stats = PrefixCacheStats()
+                combined_connector_prefix_cache_stats.reset = rank_stats.connector_prefix_cache_stats.reset
+                combined_connector_prefix_cache_stats.requests += rank_stats.connector_prefix_cache_stats.requests
+                combined_connector_prefix_cache_stats.queries += rank_stats.connector_prefix_cache_stats.queries
+                combined_connector_prefix_cache_stats.hits += rank_stats.connector_prefix_cache_stats.hits
+
+        # Average KV cache usage across ranks
+        avg_kv_cache_usage = total_kv_cache_usage / self.dp_size if self.dp_size else 0.0
+
+        return SchedulerStats(
+            num_running_reqs=total_running_reqs,
+            num_waiting_reqs=total_waiting_reqs,
+            kv_cache_usage=avg_kv_cache_usage,
+            prefix_cache_stats=combined_prefix_cache_stats,
+            connector_prefix_cache_stats=combined_connector_prefix_cache_stats,
+            spec_decoding_stats=spec_decoding_stats,
+            kv_connector_stats=kv_connector_stats.data
+            if kv_connector_stats else None,
+        )
+
+    def update_draft_token_ids(self, draft_token_ids) -> None:
+        """Forward draft token updates to the appropriate DP rank schedulers."""
+        # Group draft tokens by DP rank based on request assignments
+        rank_draft_tokens = defaultdict(lambda: {
+            "req_ids": [],
+            "draft_token_ids": []
+        })
+
+        for req_id, tokens in zip(draft_token_ids.req_ids,
+                                  draft_token_ids.draft_token_ids):
+            if req_id in self.assigned_dp_rank:
+                rank = self.assigned_dp_rank[req_id]
+                rank_draft_tokens[rank]["req_ids"].append(req_id)
+                rank_draft_tokens[rank]["draft_token_ids"].append(tokens)
+
+        for rank, draft_data in rank_draft_tokens.items():
+            # Create a draft_token_ids object for this rank (mock structure)
+            rank_draft_token_ids = type(draft_token_ids)(
+                req_ids=draft_data["req_ids"],
+                draft_token_ids=draft_data["draft_token_ids"])
+            self.input_queues[rank].put(
+                (SchedulerCommand.UPDATE_DRAFT_TOKEN_IDS,
+                 rank_draft_token_ids))
+            self._get_result_from_queue(self.output_queues[rank])
+
+    def shutdown(self) -> None:
+        """Shutdown all DP rank scheduler worker processes."""
+        # Send shutdown command to all workers
+        for rank in range(self.dp_size):
+            self.input_queues[rank].put((SchedulerCommand.SHUTDOWN, None))
+
+        # Wait for acknowledgment (blocking)
+        for rank in range(self.dp_size):
+            self._get_result_from_queue(self.output_queues[rank])
+
+        # Terminate and join all processes
+        for process in self.processes:
+            process.join(timeout=5.0)
+            if process.is_alive():
+                process.terminate()
+                process.join()
+
+        # Restore original pickle
+        _disable_cloudpickle()
+
+
+def update_vllm_config_for_dp_scheduler(vllm_config: Any) -> None:
+    """
+    Update vLLM configuration to use DPScheduler when DP size > 1.
+    """
+    dp_size = vllm_config.sharding_config.total_dp_size
+
+    if dp_size > 1:
+        if vllm_config.scheduler_config.async_scheduling:
+            vllm_config.scheduler_config._original_scheduler_cls = AsyncScheduler
+        else:
+            vllm_config.scheduler_config._original_scheduler_cls = Scheduler
+
+        vllm_config.scheduler_config.scheduler_cls = DPScheduler
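
A note on the cloudpickle monkey-patch used above: Python's multiprocessing queues serialize objects through multiprocessing.reduction.ForkingPickler, so replacing that class's dumps/loads with cloudpickle is what lets otherwise unpicklable objects (vLLM Request objects here) cross the process boundary. What follows is a minimal, self-contained sketch of the same technique, not code from tpu-inference; it assumes only the cloudpickle package and a platform with the "fork" start method, and the names _worker, q_in, and q_out are illustrative.

import multiprocessing
import multiprocessing.reduction

import cloudpickle

_orig_dumps = multiprocessing.reduction.ForkingPickler.dumps
_orig_loads = multiprocessing.reduction.ForkingPickler.loads


def _worker(q_in, q_out):
    fn = q_in.get()  # a lambda; stdlib pickle would refuse to transfer it
    q_out.put(fn(21))


if __name__ == "__main__":
    # Patch both directions, mirroring _enable_cloudpickle() in the diff.
    multiprocessing.reduction.ForkingPickler.dumps = staticmethod(
        cloudpickle.dumps)
    multiprocessing.reduction.ForkingPickler.loads = staticmethod(
        cloudpickle.loads)
    try:
        ctx = multiprocessing.get_context("fork")  # as in DPScheduler.__init__
        q_in, q_out = ctx.Queue(), ctx.Queue()
        proc = ctx.Process(target=_worker, args=(q_in, q_out))
        proc.start()
        q_in.put(lambda x: 2 * x)  # would raise PicklingError without the patch
        print(q_out.get())  # -> 42
        proc.join()
    finally:
        # Mirror _disable_cloudpickle(): the patch is process-global, so
        # restore the stdlib behavior once the workers are done.
        multiprocessing.reduction.ForkingPickler.dumps = _orig_dumps
        multiprocessing.reduction.ForkingPickler.loads = _orig_loads

The restore step matters for the same reason DPScheduler.shutdown() calls _disable_cloudpickle(): the ForkingPickler patch affects every multiprocessing queue in the process, not just the scheduler workers' queues.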