tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of tpu-inference might be problematic.

Files changed (250)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +45 -15
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +41 -16
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  240. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,27 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import copy
+ import multiprocessing.reduction
  from collections import defaultdict, deque
  from dataclasses import dataclass
+ from enum import Enum
+ from multiprocessing import Process, Queue
+ from time import time
  from typing import Any, Dict, List, Optional, Tuple

+ import cloudpickle
  import torch
  from vllm.config import VllmConfig
  from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
@@ -19,10 +38,186 @@ from vllm.v1.request import Request
  from vllm.v1.structured_output import StructuredOutputManager

  from tpu_inference.logger import init_logger
+ from tpu_inference.utils import time_function

  logger = init_logger(__name__)


+ class SchedulerCommand(Enum):
+     """Enum for scheduler worker process commands."""
+     ADD_REQUEST = "add_request"
+     SCHEDULE = "schedule"
+     FINISH_REQUESTS = "finish_requests"
+     UPDATE_DRAFT_TOKEN_IDS = "update_draft_token_ids"
+     UPDATE_FROM_OUTPUT = "update_from_output"
+     GET_GRAMMAR_BITMASK = "get_grammar_bitmask"
+     MAKE_STATS = "make_stats"
+     RESET_PREFIX_CACHE = "reset_prefix_cache"
+     GET_NUM_UNFINISHED_REQUESTS = "get_num_unfinished_requests"
+     HAS_FINISHED_REQUESTS = "has_finished_requests"
+     GET_REQUEST_COUNTS = "get_request_counts"
+     GET_TOKEN_COUNT = "get_token_count"
+     GET_COMPUTED_BLOCKS = "get_computed_blocks"
+     SHUTDOWN = "shutdown"
+
+
+ class SchedulerWorkerError(Exception):
+     """Exception raised when a scheduler worker process encounters an error."""
+
+     def __init__(self, rank: int, message: str):
+         self.rank = rank
+         self.message = message
+         super().__init__(f"Scheduler worker {rank} error: {message}")
+
+
+ # Monkey-patch multiprocessing to use cloudpickle.
+ # Standard pickle fails to serialize the vLLM Request object.
+ _original_dumps = multiprocessing.reduction.ForkingPickler.dumps
+ _original_loads = multiprocessing.reduction.ForkingPickler.loads
+
+
+ def _cloudpickle_dumps(obj, protocol=None):
+     """Use cloudpickle for serialization."""
+     return cloudpickle.dumps(obj, protocol=protocol)
+
+
+ def _cloudpickle_loads(data):
+     """Use cloudpickle for deserialization."""
+     return cloudpickle.loads(data)
+
+
+ def _enable_cloudpickle():
+     """Enable cloudpickle for multiprocessing queues."""
+     multiprocessing.reduction.ForkingPickler.dumps = staticmethod(
+         _cloudpickle_dumps)
+     multiprocessing.reduction.ForkingPickler.loads = staticmethod(
+         _cloudpickle_loads)
+
+
+ def _disable_cloudpickle():
+     """Restore original pickle for multiprocessing."""
+     multiprocessing.reduction.ForkingPickler.dumps = _original_dumps
+     multiprocessing.reduction.ForkingPickler.loads = _original_loads
+
+
+ def _scheduler_worker_process(
+     rank: int,
+     input_queue: Queue,
+     output_queues: Dict[str, Queue],
+     vllm_config: Any,
+     kv_cache_config: Any,
+     structured_output_manager: Any,
+     block_size: int,
+     mm_registry: Any,
+     include_finished_set: bool,
+     log_stats: bool,
+     original_scheduler_cls: type,
+ ):
+     """Worker process that manages a single scheduler instance."""
+     # Initialize the scheduler in this process
+     scheduler = original_scheduler_cls(
+         vllm_config=vllm_config,
+         kv_cache_config=kv_cache_config,
+         structured_output_manager=structured_output_manager,
+         block_size=block_size,
+         mm_registry=mm_registry,
+         include_finished_set=include_finished_set,
+         log_stats=log_stats,
+     )
+
+     logger.debug(f"Scheduler worker process {rank} started")
+
+     # Process commands from the input queue
+     while True:
+         try:
+             command, data = input_queue.get()
+
+             match command:
+                 case SchedulerCommand.ADD_REQUEST:
+                     request = data
+                     scheduler.add_request(request)
+                     output_queues[command.value].put(None)  # Signal completion
+
+                 case SchedulerCommand.SCHEDULE:
+                     output = scheduler.schedule()
+                     output_queues[command.value].put(output)
+
+                 case SchedulerCommand.FINISH_REQUESTS:
+                     request_ids, finished_status = data
+                     scheduler.finish_requests(request_ids, finished_status)
+                     output_queues[command.value].put(None)  # Signal completion
+
+                 case SchedulerCommand.UPDATE_DRAFT_TOKEN_IDS:
+                     draft_token_ids = data
+                     scheduler.update_draft_token_ids(draft_token_ids)
+                     output_queues[command.value].put(None)  # Signal completion
+
+                 case SchedulerCommand.UPDATE_FROM_OUTPUT:
+                     scheduler_output, model_runner_output = data
+                     result = scheduler.update_from_output(
+                         scheduler_output, model_runner_output)
+                     output_queues[command.value].put(result)
+
+                 case SchedulerCommand.GET_GRAMMAR_BITMASK:
+                     scheduler_output = data
+                     result = scheduler.get_grammar_bitmask(scheduler_output)
+                     output_queues[command.value].put(result)
+
+                 case SchedulerCommand.MAKE_STATS:
+                     spec_decoding_stats, kv_connector_stats = data
+                     result = scheduler.make_stats(spec_decoding_stats,
+                                                   kv_connector_stats)
+                     output_queues[command.value].put(result)
+
+                 case SchedulerCommand.RESET_PREFIX_CACHE:
+                     result = scheduler.reset_prefix_cache()
+                     output_queues[command.value].put(result)
+
+                 case SchedulerCommand.GET_NUM_UNFINISHED_REQUESTS:
+                     result = scheduler.get_num_unfinished_requests()
+                     output_queues[command.value].put(result)
+
+                 case SchedulerCommand.HAS_FINISHED_REQUESTS:
+                     result = scheduler.has_finished_requests()
+                     output_queues[command.value].put(result)
+
+                 case SchedulerCommand.GET_REQUEST_COUNTS:
+                     running = len(scheduler.running)
+                     waiting = len(scheduler.waiting)
+                     output_queues[command.value].put((running, waiting))
+
+                 case SchedulerCommand.GET_TOKEN_COUNT:
+                     # Calculate total tokens across running and waiting requests
+                     total_tokens = 0
+                     for req in scheduler.running:
+                         total_tokens += len(req.all_token_ids)
+                     for req in scheduler.waiting:
+                         total_tokens += len(req.all_token_ids)
+                     output_queues[command.value].put(total_tokens)
+
+                 case SchedulerCommand.GET_COMPUTED_BLOCKS:
+                     request = data
+                     blocks, cached_tokens = scheduler.kv_cache_manager.get_computed_blocks(
+                         request)
+                     output_queues[command.value].put((blocks, cached_tokens))
+
+                 case SchedulerCommand.SHUTDOWN:
+                     scheduler.shutdown()
+                     output_queues[command.value].put(None)  # Signal completion
+                     break
+                 case _:
+                     error = SchedulerWorkerError(
+                         rank, f"Unknown command: {command}")
+                     output_queues[command.value].put(error)
+                     raise error
+
+         except Exception as e:
+             logger.error(f"Error in scheduler worker {rank}: {e}",
+                          exc_info=True)
+             error = SchedulerWorkerError(rank, str(e))
+             output_queues[command.value].put(error)
+
+
  @dataclass
  class DPSchedulerOutput(SchedulerOutput):
      """Extended SchedulerOutput that includes DP rank assignments."""
@@ -77,22 +272,50 @@ class DPScheduler(SchedulerInterface):

          # The original scheduler class could be Scheduler or AsyncScheduler
          original_scheduler_cls = vllm_config.scheduler_config._original_scheduler_cls
-         self.schedulers: List[Scheduler] = []
+
+         # Enable cloudpickle for multiprocessing to handle local functions
+         _enable_cloudpickle()
+
+         # Create worker processes with separate output queues for each command type
+         import multiprocessing
+         ctx = multiprocessing.get_context('fork')
+         self.input_queues: List[Queue] = []
+         self.output_queues: Dict[Tuple[int, str], Queue] = {}
+         self.processes: List[Process] = []
+
          for rank in range(self.dp_size):
-             scheduler = original_scheduler_cls(
-                 vllm_config=self.vllm_config,
-                 kv_cache_config=self.per_rank_kv_cache_configs[rank],
-                 structured_output_manager=structured_output_manager,
-                 block_size=block_size,
-                 mm_registry=mm_registry,
-                 include_finished_set=include_finished_set,
-                 log_stats=log_stats,
+             input_queue = ctx.Queue()
+             self.input_queues.append(input_queue)
+
+             output_queues_for_rank: Dict[str, Queue] = {}
+             for cmd in SchedulerCommand:
+                 output_queues_for_rank[cmd.value] = ctx.Queue()
+                 self.output_queues[(
+                     rank, cmd.value)] = output_queues_for_rank[cmd.value]
+
+             process = ctx.Process(
+                 target=_scheduler_worker_process,
+                 args=(
+                     rank,
+                     input_queue,
+                     output_queues_for_rank,
+                     self.vllm_config,
+                     self.per_rank_kv_cache_configs[rank],
+                     structured_output_manager,
+                     block_size,
+                     mm_registry,
+                     include_finished_set,
+                     log_stats,
+                     original_scheduler_cls,
+                 ),
              )
-             self.schedulers.append(scheduler)
+             process.start()
+             self.processes.append(process)

          logger.info(
              f"DPScheduler (Async = {self.vllm_config.scheduler_config.async_scheduling}) "
-             f"per-rank limits: max_seqs={self.vllm_config.scheduler_config.max_num_seqs}, "
+             f"started {self.dp_size} worker processes with cloudpickle. "
+             f"Per-rank limits: max_seqs={self.vllm_config.scheduler_config.max_num_seqs}, "
              f"max_tokens={self.vllm_config.scheduler_config.max_num_batched_tokens}"
          )

@@ -103,15 +326,39 @@
              rank_config.num_blocks = kv_cache_config.num_blocks // self.dp_size
              self.per_rank_kv_cache_configs.append(rank_config)

+     def _get_result_from_queue(self, rank: int,
+                                command: SchedulerCommand) -> Any:
+         """Get result from the output queue for a specific rank and command type."""
+         queue_obj = self.output_queues[(rank, command.value)]
+         try:
+             start_time = time()
+             result = queue_obj.get()
+             end_time = time()
+             if end_time - start_time > 1.0:
+                 logger.warning(
+                     f"Long wait time ({end_time - start_time:.2f}s) for rank {rank} "
+                     f"command {command.value} response.")
+         except EOFError as e:
+             raise RuntimeError(
+                 f"Queue error for rank {rank} command {command.value}: "
+                 "Worker process terminated unexpectedly. "
+                 "This may indicate a crash in the scheduler worker process."
+             ) from e
+         if isinstance(result, SchedulerWorkerError):
+             raise result
+         return result
+
      def _get_rank_token_counts(self) -> Dict[int, int]:
          """Calculate total tokens currently assigned to each DP rank."""
-         rank_tokens = {rank: 0 for rank in range(self.dp_size)}
+         for rank in range(self.dp_size):
+             self.input_queues[rank].put(
+                 (SchedulerCommand.GET_TOKEN_COUNT, None))

-         for rank, scheduler in enumerate(self.schedulers):
-             for request in scheduler.running:
-                 rank_tokens[rank] += request.num_tokens
-             for request in scheduler.waiting:
-                 rank_tokens[rank] += request.num_tokens
+         rank_tokens = {}
+         for rank in range(self.dp_size):
+             token_count = self._get_result_from_queue(
+                 rank, SchedulerCommand.GET_TOKEN_COUNT)
+             rank_tokens[rank] = token_count

          return rank_tokens

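Note the two-pass shape of _get_rank_token_counts: every GET_TOKEN_COUNT command is enqueued before any reply is awaited, so all ranks work concurrently and the wall-clock cost is roughly that of the slowest rank rather than the sum. The same fan-out/fan-in shape recurs in schedule, update_from_output, and the other broadcast methods below. A hedged sketch of the pattern (a helper of this name does not exist in the package):

def broadcast_and_collect(input_queues, output_queues, dp_size, command,
                          payload=None):
    # Fan-out: non-blocking puts, so every worker starts at about the same time.
    for rank in range(dp_size):
        input_queues[rank].put((command, payload))
    # Fan-in: block on each rank's dedicated (rank, command) reply queue.
    results = []
    for rank in range(dp_size):
        result = output_queues[(rank, command.value)].get()
        if isinstance(result, Exception):  # workers ship errors back as values
            raise result
        results.append(result)
    return results

Keying the reply queues by (rank, command) means replies to different commands can never be interleaved, so no correlation IDs are needed.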
@@ -120,11 +367,15 @@ class DPScheduler(SchedulerInterface):
          rank_tokens = self._get_rank_token_counts()

          # First, try to find a rank with prefix cache hit
+         for rank in range(self.dp_size):
+             self.input_queues[rank].put(
+                 (SchedulerCommand.GET_COMPUTED_BLOCKS, request))
+
          best_cache_rank = None
          best_cache_tokens = 0
-         for rank, scheduler in enumerate(self.schedulers):
-             blocks, cached_tokens = scheduler.kv_cache_manager.get_computed_blocks(
-                 request)
+         for rank in range(self.dp_size):
+             blocks, cached_tokens = self._get_result_from_queue(
+                 rank, SchedulerCommand.GET_COMPUTED_BLOCKS)
              if cached_tokens > best_cache_tokens:
                  best_cache_tokens = cached_tokens
                  best_cache_rank = rank
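
The hunk ends before the fallback path, but since rank_tokens is computed up front, the routing policy is presumably: take the rank with the largest prefix-cache hit, otherwise the least-loaded rank by token count. A sketch of that presumed policy (illustrative names, not the package's code):

def pick_rank(cache_hits: dict[int, int], rank_tokens: dict[int, int]) -> int:
    # Presumed policy: reuse already-computed KV blocks when any rank has them.
    best_rank = max(cache_hits, key=cache_hits.get)
    if cache_hits[best_rank] > 0:
        return best_rank
    # Otherwise balance load by routing to the rank holding the fewest tokens.
    return min(rank_tokens, key=rank_tokens.get)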
@@ -149,26 +400,30 @@ class DPScheduler(SchedulerInterface):
              f"assigned to rank {self.assigned_dp_rank[request.request_id]})")
          rank = self._find_best_rank_for_request(request)
          self.assigned_dp_rank[request.request_id] = rank
-         self.schedulers[rank].add_request(request)

+         self.input_queues[rank].put((SchedulerCommand.ADD_REQUEST, request))
+         self._get_result_from_queue(rank, SchedulerCommand.ADD_REQUEST)
+
+     @time_function
      def schedule(self) -> DPSchedulerOutput:
          """
          Main scheduling method that coordinates all DP rank schedulers.

          Process:
          1. Add any new requests to appropriate DP ranks
-         2. Run each scheduler independently
+         2. Run each scheduler independently in parallel
          3. Combine outputs from all schedulers
          4. Return unified scheduling result
          """
          # Run each scheduler independently
+         for rank in range(self.dp_size):
+             self.input_queues[rank].put((SchedulerCommand.SCHEDULE, None))
+
+         # Collect outputs from all workers (blocking)
          rank_outputs = []
-         for rank, scheduler in enumerate(self.schedulers):
-             logger.debug(
-                 f"Running scheduler for rank {rank}: "
-                 f"{len(scheduler.running)} running, {len(scheduler.waiting)} waiting"
-             )
-             output = scheduler.schedule()
+         for rank in range(self.dp_size):
+             output = self._get_result_from_queue(rank,
+                                                  SchedulerCommand.SCHEDULE)
              rank_outputs.append(output)

          # Cache scheduler outputs to use in `update_from_output`
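
schedule() is now wrapped in @time_function, imported from tpu_inference.utils; its body is not part of this diff. A typical implementation of such a decorator looks roughly like this (a sketch of the assumed behavior, not the package's actual code):

import functools
import logging
import time

logger = logging.getLogger(__name__)


def time_function(fn):
    """Log the wall-clock duration of each call to fn (assumed behavior)."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return fn(*args, **kwargs)
        finally:
            logger.debug("%s took %.4fs", fn.__name__,
                         time.perf_counter() - start)

    return wrapper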
@@ -292,10 +547,12 @@ class DPScheduler(SchedulerInterface):
          combined_bitmasks = []

          # Get grammar bitmask from each DP rank scheduler
-         for rank, scheduler in enumerate(self.schedulers):
-             rank_output = rank_scheduler_outputs[rank]
-             grammar_output = scheduler.get_grammar_bitmask(rank_output)
-
+         for rank in range(self.dp_size):
+             self.input_queues[rank].put((SchedulerCommand.GET_GRAMMAR_BITMASK,
+                                          rank_scheduler_outputs[rank]))
+         for rank in range(self.dp_size):
+             grammar_output = self._get_result_from_queue(
+                 rank, SchedulerCommand.GET_GRAMMAR_BITMASK)
              if grammar_output is not None:
                  combined_structured_output_request_ids.extend(
                      grammar_output.structured_output_request_ids)
@@ -328,10 +585,15 @@ class DPScheduler(SchedulerInterface):
              model_runner_output)
          rank_scheduler_outputs = self.cached_schedulers_output.popleft()
          # Update each scheduler with its portion of the output
+         for rank in range(self.dp_size):
+             self.input_queues[rank].put(
+                 (SchedulerCommand.UPDATE_FROM_OUTPUT,
+                  (rank_scheduler_outputs[rank], rank_model_outputs[rank])))
+
          combined_engine_outputs = defaultdict(list)
-         for rank, scheduler in enumerate(self.schedulers):
-             rank_engine_outputs = scheduler.update_from_output(
-                 rank_scheduler_outputs[rank], rank_model_outputs[rank])
+         for rank in range(self.dp_size):
+             rank_engine_outputs = self._get_result_from_queue(
+                 rank, SchedulerCommand.UPDATE_FROM_OUTPUT)
              for client_idx, engine_output in rank_engine_outputs.items():
                  combined_engine_outputs[client_idx].append(engine_output)

@@ -397,30 +659,62 @@ class DPScheduler(SchedulerInterface):

          # Forward to each scheduler
          for rank, req_ids in rank_request_ids.items():
-             self.schedulers[rank].finish_requests(req_ids, finished_status)
+             self.input_queues[rank].put(
+                 (SchedulerCommand.FINISH_REQUESTS, (req_ids, finished_status)))
+             self._get_result_from_queue(rank, SchedulerCommand.FINISH_REQUESTS)

      def get_num_unfinished_requests(self) -> int:
          """Get total number of unfinished requests across all DP ranks."""
-         return sum(scheduler.get_num_unfinished_requests()
-                    for scheduler in self.schedulers)
+         for rank in range(self.dp_size):
+             self.input_queues[rank].put(
+                 (SchedulerCommand.GET_NUM_UNFINISHED_REQUESTS, None))
+
+         total = 0
+         for rank in range(self.dp_size):
+             count = self._get_result_from_queue(
+                 rank, SchedulerCommand.GET_NUM_UNFINISHED_REQUESTS)
+             total += count
+         return total

      def has_finished_requests(self) -> bool:
          """Check if any DP rank has finished requests."""
-         return any(scheduler.has_finished_requests()
-                    for scheduler in self.schedulers)
+         for rank in range(self.dp_size):
+             self.input_queues[rank].put(
+                 (SchedulerCommand.HAS_FINISHED_REQUESTS, None))
+
+         has_finished_any = False
+         for rank in range(self.dp_size):
+             has_finished_any |= self._get_result_from_queue(
+                 rank, SchedulerCommand.HAS_FINISHED_REQUESTS)
+         return has_finished_any

      def get_request_counts(self) -> Tuple[int, int]:
          """Get total (running, waiting) request counts across all DP ranks."""
-         total_running = sum(
-             len(scheduler.running) for scheduler in self.schedulers)
-         total_waiting = sum(
-             len(scheduler.waiting) for scheduler in self.schedulers)
+         for rank in range(self.dp_size):
+             self.input_queues[rank].put(
+                 (SchedulerCommand.GET_REQUEST_COUNTS, None))
+
+         total_running = 0
+         total_waiting = 0
+         for rank in range(self.dp_size):
+             running, waiting = self._get_result_from_queue(
+                 rank, SchedulerCommand.GET_REQUEST_COUNTS)
+             total_running += running
+             total_waiting += waiting
          return total_running, total_waiting

      def reset_prefix_cache(self) -> bool:
          """Reset prefix cache for all DP rank schedulers."""
-         return all(scheduler.reset_prefix_cache()
-                    for scheduler in self.schedulers)
+         for rank in range(self.dp_size):
+             self.input_queues[rank].put(
+                 (SchedulerCommand.RESET_PREFIX_CACHE, None))
+
+         all_success = True
+         for rank in range(self.dp_size):
+             success = self._get_result_from_queue(
+                 rank, SchedulerCommand.RESET_PREFIX_CACHE)
+             all_success &= success
+         return all_success

      def make_stats(self,
                     spec_decoding_stats=None,
@@ -438,9 +732,14 @@ class DPScheduler(SchedulerInterface):
          combined_connector_prefix_cache_stats: Optional[
              PrefixCacheStats] = None

-         for scheduler in self.schedulers:
-             rank_stats = scheduler.make_stats(spec_decoding_stats,
-                                               kv_connector_stats)
+         for rank in range(self.dp_size):
+             self.input_queues[rank].put(
+                 (SchedulerCommand.MAKE_STATS, (spec_decoding_stats,
+                                                kv_connector_stats)))
+
+         for rank in range(self.dp_size):
+             rank_stats = self._get_result_from_queue(
+                 rank, SchedulerCommand.MAKE_STATS)
              if rank_stats is None:
                  continue

@@ -465,8 +764,7 @@ class DPScheduler(SchedulerInterface):
              combined_connector_prefix_cache_stats.hits += rank_stats.connector_prefix_cache_stats.hits

          # Average KV cache usage across ranks
-         avg_kv_cache_usage = total_kv_cache_usage / len(
-             self.schedulers) if self.schedulers else 0.0
+         avg_kv_cache_usage = total_kv_cache_usage / self.dp_size if self.dp_size else 0.0

          return SchedulerStats(
              num_running_reqs=total_running_reqs,
@@ -494,18 +792,36 @@ class DPScheduler(SchedulerInterface):
              rank_draft_tokens[rank]["req_ids"].append(req_id)
              rank_draft_tokens[rank]["draft_token_ids"].append(tokens)

-         # Forward to each scheduler
          for rank, draft_data in rank_draft_tokens.items():
              # Create a draft_token_ids object for this rank (mock structure)
              rank_draft_token_ids = type(draft_token_ids)(
                  req_ids=draft_data["req_ids"],
                  draft_token_ids=draft_data["draft_token_ids"])
-             self.schedulers[rank].update_draft_token_ids(rank_draft_token_ids)
+             self.input_queues[rank].put(
+                 (SchedulerCommand.UPDATE_DRAFT_TOKEN_IDS,
+                  rank_draft_token_ids))
+             self._get_result_from_queue(
+                 rank, SchedulerCommand.UPDATE_DRAFT_TOKEN_IDS)

      def shutdown(self) -> None:
-         """Shutdown all DP rank schedulers."""
-         for scheduler in self.schedulers:
-             scheduler.shutdown()
+         """Shutdown all DP rank scheduler worker processes."""
+         # Send shutdown command to all workers
+         for rank in range(self.dp_size):
+             self.input_queues[rank].put((SchedulerCommand.SHUTDOWN, None))
+
+         # Wait for acknowledgment (blocking)
+         for rank in range(self.dp_size):
+             self._get_result_from_queue(rank, SchedulerCommand.SHUTDOWN)
+
+         # Terminate and join all processes
+         for process in self.processes:
+             process.join(timeout=5.0)
+             if process.is_alive():
+                 process.terminate()
+                 process.join()
+
+         # Restore original pickle
+         _disable_cloudpickle()


  def update_vllm_config_for_dp_scheduler(vllm_config: Any) -> None:
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from typing import Any, Optional

  import jax
@@ -88,7 +88,7 @@ if TYPE_CHECKING:
  from tpu_inference import envs
  from tpu_inference.distributed.utils import (get_host_ip, get_kv_ips,
                                               get_kv_ports,
-                                              get_kv_transfer_port, get_node_id,
+                                              get_kv_transfer_port,
                                               get_side_channel_port)
  from tpu_inference.logger import init_logger
  from tpu_inference.runner.tpu_runner import TPUModelRunner
@@ -442,10 +442,10 @@ class TPUConnectorWorker:
          self.runner: TPUModelRunner = None
          self.mesh: Mesh = None
          self.multi_host = envs.TPU_MULTIHOST_BACKEND == "ray"
-         # NOTE(xiang): This can not be the worker rank set in RayDistributedExecutor.
-         # The worker rank is assigned with vLLM's sorting logic, which does not work
-         # for TPU host topology.
-         self.node_id = get_node_id()
+         # Default value for the non-distributed scenario.
+         # Once the topology is initialized, the runner updates it
+         # based on topology_order_id.
+         self.node_id = 0

          # req_id: (kv, expiration_time)
          self.reqs_wait_pull: dict[ReqId, list[list[jax.Array], float]] = {}
@@ -472,7 +472,7 @@ class TPUConnectorWorker:
          self.pull_conns: dict[str, Any] = {}
          self.notif_sockets: dict[str, zmq.Socket] = {}

-         logger.info(f"TPUConnector Worker {self.node_id} --> init | "
+         logger.info(f"TPUConnector Worker --> init | "
                      f"ip={self.host_ip} | "
                      f"kv_transfer_port={self.kv_transfer_port} | "
                      f"side_channel_port={self.side_channel_port}")
@@ -488,6 +488,7 @@ class TPUConnectorWorker:
          self.zmq_cxt.destroy(linger=0)

      def register_runner(self, runner: TPUModelRunner):
+         self.node_id = runner.topology_order_id
          self.runner = runner
          self.mesh = runner.mesh

@@ -498,6 +499,10 @@ class TPUConnectorWorker:
          self.shape = list(kv_layer.shape)
          self.dtype = kv_layer.dtype
          self.sharding = kv_layer.sharding
+         logger.info(f"TPUConnector Worker --> register_runner | "
+                     f"node_id={self.node_id} | "
+                     f"ip={self.host_ip} | "
+                     f"kv_transfer_port={self.kv_transfer_port}")
          self._maybe_start_p2p_server()

      def _maybe_start_p2p_server(self):
@@ -694,9 +699,9 @@ class TPUConnectorWorker:

  def get_uuid() -> int:
      int128 = uuid4().int
-     # Must be 64-bit int, otherwise vllm output encoder would raise error.
-     int64 = int128 >> 64
-     return int64
+     # Must be less than a 64-bit int, otherwise the vLLM output encoder raises an error.
+     # Use 50 bits so Go does not truncate the int during JSON serialization.
+     return int128 >> 78


  @jax.jit
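
The shift change in get_uuid (from >> 64 to >> 78) is about IEEE-754 doubles: JSON decoders in Go and JavaScript commonly parse numbers as float64, which represents integers exactly only up to 2**53, so a full 64-bit id can be silently corrupted in transit. Shifting the 128-bit UUID right by 78 leaves 50 bits, comfortably inside the safe range. A quick check of the precision loss this avoids:

from uuid import uuid4

big = (1 << 63) + 1                # a 64-bit id, as the old `>> 64` produced
assert int(float(big)) != big      # corrupted by a float64 round trip

small = uuid4().int >> 78          # a 50-bit id, as the new code produces
assert small < 2**53               # inside float64's exact-integer range
assert int(float(small)) == small  # survives JSON via float64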