tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
@@ -0,0 +1,786 @@
+# SPDX-License-Identifier: Apache-2.0
+import functools
+import itertools
+import math
+import os
+import queue
+import signal
+import threading
+import time
+import traceback
+from typing import Any, Callable, Optional, Tuple, TypeVar, Union
+
+import jax
+# ======================================================================================
+# Imports for DisaggEngineCoreProc (the vLLM adapter)
+# ======================================================================================
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.tasks import POOLING_TASKS, SupportedTask
+from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
+                                         init_none_hash)
+from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
+                            EngineCoreRequestType, UtilityOutput,
+                            UtilityResult)
+from vllm.v1.engine.core import EngineCore as vLLMEngineCore
+from vllm.v1.engine.core import EngineCoreProc as vLLMEngineCoreProc
+from vllm.v1.executor.abstract import Executor
+from vllm.v1.request import Request, RequestStatus
+
+from tpu_inference import utils as common_utils
+from tpu_inference.core import disagg_executor, disagg_utils
+from tpu_inference.runner.tpu_runner import AsyncTPUModelRunnerOutput
+# ======================================================================================
+# Imports for _DisaggOrchestrator (decoupled from vLLM)
+# ======================================================================================
+from tpu_inference.runner.utils import LatencyTracker
+
+# This file contains two classes:
+# 1. _DisaggOrchestrator: The clean, decoupled core orchestration logic.
+# 2. DisaggEngineCoreProc: The vLLM-facing adapter that handles process management.
+
+logger = init_logger(__name__)
+
+POLLING_TIMEOUT_S = 2.5
+HANDSHAKE_TIMEOUT_MINS = 5
+
+_R = TypeVar('_R')  # Return type for collective_rpc
+
+# ======================================================================================
+# Class 1: The Decoupled Orchestrator
+# ======================================================================================
+
+
+class JetThread(threading.Thread):
+    """Thread that kills the program if it fails.
+
+    If a driver thread goes down, we can't operate.
+    """
+
+    def run(self):
+        try:
+            super().run()
+        except Exception as e:  # pylint: disable=broad-exception-caught
+            print(f"Thread {self.name} encountered an error: {e}")
+            traceback.print_exc()
+            os.kill(os.getpid(), signal.SIGKILL)
+
+
+class _DisaggOrchestrator:
+    """Contains the core orchestration logic, decoupled from vLLM."""
+
+    def __init__(
+        self,
+        config: VllmConfig,
+        output_queue: queue.Queue,
+        prefill_engines: list[vLLMEngineCore],
+        decode_engines: list[vLLMEngineCore],
+        prefill_slice_sizes: tuple[int, ...],
+        decode_slice_sizes: tuple[int, ...],
+    ):
+        self._config = config
+        self._output_queue = output_queue
+        self._prefill_engines = prefill_engines
+        self._decode_engines = decode_engines
+
+        # Keep track of active requests.
+        self._requests: dict[str, Request] = {}
+
+        # Hack device config to pass in the subslice of TPUs.
+        slice_sizes = list(prefill_slice_sizes)
+        slice_sizes.extend(decode_slice_sizes)
+
+        self._transfer_backlogs = [
+            queue.Queue(4) for i in range(len(self._prefill_engines))
+        ]
+
+        self._decode_backlogs = {}
+        for idx, engine in enumerate(self._decode_engines):
+            # Size the decode backlog as the remaining HBM divided by the max KV cache size of a single request.
+            runner = engine.model_executor.driver_worker.model_runner
+            hbm_usage = common_utils.hbm_usage_bytes(
+                engine.model_executor.driver_worker.devices)
+            if not hbm_usage:
+                self._decode_backlogs[idx] = queue.Queue(
+                    self._config.scheduler_config.max_num_seqs)
+                continue
+            hbm_free = [limit - used for used, limit in hbm_usage]
+            max_kv_bytes = len(runner.kv_caches) * (
+                runner.max_model_len // runner.cache_config.block_size) * (
+                    runner.kv_caches[0][0].nbytes) // len(hbm_free)
+            max_queue_len = min(hbm_free[0] // max_kv_bytes,
+                                self._config.scheduler_config.max_num_seqs)
+            logger.debug(
+                f"max kv bytes: {max_kv_bytes}, max_queue_len {max_queue_len}")
+            self._decode_backlogs[idx] = queue.Queue(max_queue_len)
+
+        self._prefill_threads = [
+            JetThread(
+                target=functools.partial(self._prefill, idx),
+                name=f"prefill-{idx}",
+                daemon=True,
+            ) for idx in range(len(self._prefill_engines))
+        ]
+        self._transfer_threads = [
+            JetThread(
+                target=functools.partial(
+                    self._transfer,
+                    idx,
+                ),
+                name=f"transfer-{idx}",
+                daemon=True,
+            ) for idx in range(len(self._prefill_engines))
+        ]
+        self._decode_threads = [
+            JetThread(
+                target=functools.partial(
+                    self._decode,
+                    idx,
+                ),
+                name=f"decode-{idx}",
+                daemon=True,
+            ) for idx in range(len(self._decode_engines))
+        ]
+        self._all_threads = list(
+            itertools.chain(
+                self._prefill_threads,
+                self._transfer_threads,
+                self._decode_threads,
+            ))
+        self.live = True
+        # Start all threads
+        for t in self._all_threads:
+            t.start()
+
+    def add_request(self, request: Request):
+        """
+        Adds a new request to the orchestrator.
+
+        This is the main entry point for new work. It stores the request for
+        internal state tracking and hands it off to the first stage of the
+        processing pipeline (the prefill scheduler).
+        """
+        # Hand off the request to the prefill scheduler to be batched for execution.
+        self._prefill_engines[0].scheduler.add_request(request)
+
+        # Add to internal state for tracking by other threads.
+        # The key is the request_id, the value is the request object.
+        self._requests[request.request_id] = request
+
+    def _prefill(self, idx: int):
+        prefill_engine = self._prefill_engines[idx]
+        transfer_backlog = self._transfer_backlogs[idx]
+
+        while self.live:
+            if not prefill_engine.scheduler.has_requests():
+                time.sleep(0.05)
+                continue
+
+            scheduler_output = prefill_engine.scheduler.schedule()
+            with LatencyTracker(f"prefill-{idx}"):
+                future = prefill_engine.model_executor.execute_model(
+                    scheduler_output, non_block=True)
+                grammar_output = prefill_engine.scheduler.get_grammar_bitmask(
+                    scheduler_output)
+                with prefill_engine.log_error_detail(scheduler_output):
+                    model_output = future.result()
+                    if model_output is None:
+                        model_output = prefill_engine.model_executor.sample_tokens(
+                            grammar_output)
+                if isinstance(model_output, AsyncTPUModelRunnerOutput):
+                    model_output = model_output.get_output()
+
+            if scheduler_output.total_num_scheduled_tokens > 0:
+                logger.debug(f"Prefill result: {model_output}")
+
+            kv_cache_map: dict[str, Tuple[list[jax.Array], list[Any]]] = {}
+            for req_id, idx in model_output.req_id_to_index.items():
+                if len(model_output.sampled_token_ids[idx]) > 0:
+                    request = self._requests[req_id]
+                    block_ids = (prefill_engine.scheduler.kv_cache_manager.
+                                 get_block_ids(req_id))
+                    # Assume one KV cache group for now.
+                    kv_cache_map[req_id] = (
+                        prefill_engine.model_executor.driver_worker.
+                        model_runner.get_kv_cache_for_block_ids(
+                            block_ids[0]), request.block_hashes)
+                    logger.debug(f"prefill done: for {req_id}")
+            transfer_backlog.put(kv_cache_map, block=True)
+
+            # Tweak model_output to let the scheduler know kv_transfer is done for requests, so they can be freed.
+            engine_core_outputs = prefill_engine.scheduler.update_from_output(
+                scheduler_output, model_output)  # type: ignore
+
+            for req_id, idx in model_output.req_id_to_index.items():
+                if len(model_output.sampled_token_ids[idx]) > 0:
+                    request = self._requests[req_id]
+                    logger.debug(
+                        f"request block_hashes at prefill: {request.block_hashes}"
+                    )
+                    logger.debug(
+                        f"request-{req_id}: tokens={request.all_token_ids} after prefill"
+                    )
+                    # Remove request from the prefill engine.
+                    if req_id in prefill_engine.scheduler.requests:
+                        request = prefill_engine.scheduler.requests[req_id]
+                        prefill_engine.scheduler.running.remove(request)
+                        prefill_engine.scheduler.encoder_cache_manager.free(
+                            request)
+
+                        prefill_engine.scheduler.kv_cache_manager.free(
+                            request)
+
+                        prefill_engine.scheduler.requests.pop(req_id)
+
+            for output in (engine_core_outputs.items()
+                           if engine_core_outputs else ()):
+                self._output_queue.put_nowait(output)
+
+    def _transfer(self, idx: int):
+        """Transfers the kv cache on an active request to the least full
+        decode backlog."""
+        transfer_backlog = self._transfer_backlogs[idx]
+        while self.live:
+            # The transfer thread can just sleep until it has work to do.
+            kv_cache_map = transfer_backlog.get(block=True)
+            if kv_cache_map is None:
+                break
+
+            logger.debug(
+                f"transfer-{idx}: KV Cache items received: {kv_cache_map.keys()}"
+            )
+
+            push_targets = []
+            for req_id, (kv_cache, block_hashes) in kv_cache_map.items():
+                target_idx = -1
+                cnt = 9999999
+                for i, e in enumerate(self._decode_engines):
+                    req_cnt = sum(e.scheduler.get_request_counts())
+                    if req_cnt < cnt:
+                        cnt = req_cnt
+                        target_idx = i
+
+                # Only transfer the KVCache for the disaggregated serving.
+                with LatencyTracker("KVCacheTransfer"):
+                    kv_cache = self._decode_engines[
+                        target_idx].model_executor.driver_worker.model_runner.transfer_kv_cache(
+                            kv_cache)
+
+                # TODO(fhzhang): Now how do we get the kv cache to the decode engine?
+                prefill_output = {
+                    "cache": kv_cache,
+                    "req_id": req_id,
+                    "block_hashes": block_hashes,
+                }
+                push_targets.append((target_idx, prefill_output))
+
+            for target_idx, prefill_output in push_targets:
+                self._decode_backlogs[target_idx].put(prefill_output,
+                                                      block=True)
+                logger.debug(
+                    "Successfully transferred prefill request %s "
+                    "from prefill engine %d to decode engine %d. decode backlog len %d",
+                    prefill_output["req_id"],
+                    idx,
+                    target_idx,
+                    self._decode_backlogs[target_idx].qsize(),
+                )
+
+    def _decode(self, idx: int):
+        decode_engine = self._decode_engines[idx]
+        decode_backlog = self._decode_backlogs[idx]
+
+        while self.live:
+            block = not decode_engine.scheduler.has_requests()
+
+            while True:
+                # We need to check the input batch as well, since request completion is delayed
+                # from the scheduler to the runner.
+                if (sum(decode_engine.scheduler.get_request_counts())
+                        >= self._config.scheduler_config.max_num_seqs
+                        or decode_engine.model_executor.driver_worker.
+                        model_runner.input_batch.num_reqs
+                        >= self._config.scheduler_config.max_num_seqs):
+                    break
+
+                try:
+                    prefill_output = decode_backlog.get(block=block,
+                                                        timeout=1.0)
+                except queue.Empty:
+                    if block:
+                        continue
+                    break
+
+                if prefill_output is None:
+                    logger.debug(
+                        f"decode-{idx} Empty output, and we are idle, exiting..."
+                    )
+                    break
+
+                # We got a request, set block to False
+                block = False
+
+                # Insert the request to the decoder.
+                req_id = prefill_output["req_id"]
+                vllm_request = self._requests[req_id]
+                # Cache num_computed_tokens: the kv manager allocates blocks for
+                # num_computed_tokens + num_new_tokens, so without caching it here
+                # the token count would double.
+                prompt_tokens = vllm_request.num_computed_tokens
+                vllm_request.num_computed_tokens = 0
+                kv_cache = prefill_output["cache"]
+
+                kv_cache_manager = decode_engine.scheduler.kv_cache_manager
+                kv_cache_manager.allocate_slots(
+                    vllm_request,
+                    prompt_tokens,
+                )
+                vllm_request.num_computed_tokens = prompt_tokens
+                new_block_ids = kv_cache_manager.get_block_ids(req_id)
+                logger.debug(
+                    f"inserting {req_id} new_block_ids {new_block_ids}")
+                if len(new_block_ids[0]) != math.ceil(
+                        prompt_tokens / self._config.cache_config.block_size):
+                    logger.warning("Running out of blocks in decode engine!")
+                    break
+
+                decode_engine.model_executor.driver_worker.model_runner.insert_request_with_kv_cache(
+                    vllm_request, kv_cache, new_block_ids)
+
+                vllm_request.status = RequestStatus.RUNNING
+                block_hashes = prefill_output["block_hashes"]
+                vllm_request.block_hashes = block_hashes
+                decode_engine.scheduler.running.append(vllm_request)
+                decode_engine.scheduler.requests[req_id] = vllm_request
+
+                self._requests.pop(req_id)
+
+            scheduler_output = decode_engine.scheduler.schedule()
+
+            logger.debug(f'''decode-{idx}: scheduler_output -
+                {scheduler_output.scheduled_cached_reqs.num_computed_tokens},
+                new block ids - {scheduler_output.scheduled_cached_reqs.new_block_ids}'''
+                         )
+
+            with LatencyTracker(f"decode-{idx}"):
+                future = decode_engine.model_executor.execute_model(
+                    scheduler_output, non_block=True)
+                grammar_output = decode_engine.scheduler.get_grammar_bitmask(
+                    scheduler_output)
+                with decode_engine.log_error_detail(scheduler_output):
+                    model_output = future.result()
+                    if model_output is None:
+                        model_output = decode_engine.model_executor.sample_tokens(
+                            grammar_output)
+                if isinstance(model_output, AsyncTPUModelRunnerOutput):
+                    model_output = model_output.get_output()
+
+            if scheduler_output.total_num_scheduled_tokens > 0:
+                logger.debug(f"Decode result: {model_output}")
+
+            engine_core_outputs = decode_engine.scheduler.update_from_output(
+                scheduler_output, model_output)  # type: ignore
+            for output in (engine_core_outputs.items()
+                           if engine_core_outputs else ()):
+                self._output_queue.put_nowait(output)
+
+    def shutdown(self):
+        for e in self._prefill_engines:
+            e.shutdown()
+        for e in self._decode_engines:
+            e.shutdown()
+
+
+# ======================================================================================
+# Class 2: The vLLM-Facing Adapter
+# ======================================================================================
+
+
+def _create_engine_cores(
+    slice_sizes: tuple[int, ...],
+    vllm_config: VllmConfig,
+    log_stats: bool,
+    executor_fail_callback: Optional[Callable] = None,
+) -> list[vLLMEngineCore]:
+    engine_cores = []
+    for _ in slice_sizes:
+        engine_core = vLLMEngineCore(
+            vllm_config,
+            disagg_executor.DisaggExecutor,
+            log_stats,
+            executor_fail_callback,
+        )
+
+        engine_cores.append(engine_core)
+        logger.warning("Disaggregated engine core created.")
+
+    return engine_cores
+
+
+def _get_slice_sizes(devices):
+    prefill_slice_sizes = disagg_utils.get_prefill_slices()
+    decode_slice_sizes = disagg_utils.get_decode_slices()
+    if isinstance(prefill_slice_sizes[0], int):
+        prefill_chip_cnt = sum(prefill_slice_sizes)
+    else:
+        prefill_chip_cnt = sum([math.prod(t) for t in prefill_slice_sizes])
+    if isinstance(decode_slice_sizes[0], int):
+        decode_chip_cnt = sum(decode_slice_sizes)
+    else:
+        decode_chip_cnt = sum([math.prod(t) for t in decode_slice_sizes])
+    assert decode_chip_cnt + prefill_chip_cnt <= len(devices)
+    assert prefill_chip_cnt > 0 and decode_chip_cnt > 0
+
+    slice_sizes = list(prefill_slice_sizes)
+    slice_sizes.extend(decode_slice_sizes)
+    return prefill_slice_sizes, decode_slice_sizes, slice_sizes
+
+
+class DisaggEngineCore(vLLMEngineCore):
+    """The vLLM-facing adapter that handles process management and I/O. Modifies vLLMEngineCore and is only used in the in-process EngineCore client."""
+
+    @staticmethod
+    def is_supported() -> bool:
+        """
+        Returns True if this engine can run in the current environment.
+        """
+        return disagg_utils.is_disagg_enabled()
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        executor_class: type[Executor],
+        log_stats: bool,
+        executor_fail_callback: Optional[Callable] = None,
+    ):
+        self.vllm_config = vllm_config
+
+        self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
+                                              bytes]]()
+
+        self.devices = jax.devices()
+        device_kind = self.devices[0].device_kind
+        if device_kind != 'TPU7x':
+            self.vllm_config.cache_config.gpu_memory_utilization = (
+                self.vllm_config.cache_config.gpu_memory_utilization - 0.1)
+        prefill_slice_sizes, decode_slice_sizes, slice_sizes = _get_slice_sizes(
+            self.devices)
+
+        if isinstance(slice_sizes[0], int):
+            setattr(vllm_config.device_config, "slice",
+                    (0, slice_sizes, self.devices))
+        else:
+            setattr(vllm_config.device_config, "slice",
+                    ((0, 0), 0, slice_sizes, self.devices))
+        logger.info(
+            f"Creating DisaggEngineCore with slice_sizes {slice_sizes}...")
+
+        self._prefill_engines = _create_engine_cores(
+            prefill_slice_sizes,
+            vllm_config,
+            log_stats,
+            executor_fail_callback,
+        )
+        logger.info(
+            f"{len(self._prefill_engines)} Disaggregated prefill engines created."
+        )
+
+        self._decode_engines = _create_engine_cores(
+            decode_slice_sizes,
+            vllm_config,
+            log_stats,
+            executor_fail_callback,
+        )
+        logger.info(
+            f"{len(self._decode_engines)} Disaggregated decode engines created."
+        )
+
+        self.batch_queue = None
+
+        self.request_block_hasher = None
+        if (self.vllm_config.cache_config.enable_prefix_caching
+                or self._prefill_engines[0].scheduler.get_kv_connector()
+                is not None):
+
+            block_size = vllm_config.cache_config.block_size
+            caching_hash_fn = common_utils.get_hash_fn_by_name(
+                vllm_config.cache_config.prefix_caching_hash_algo)
+            init_none_hash(caching_hash_fn)
+
+            self.request_block_hasher = get_request_block_hasher(
+                block_size, caching_hash_fn)
+
+        self.step_fn = (self.step if self.batch_queue is None else
+                        self.step_with_batch_queue)
+
+        self.mm_receiver_cache = None
+        self._orchestrator = _DisaggOrchestrator(
+            config=vllm_config,
+            output_queue=self.output_queue,
+            prefill_engines=self._prefill_engines,
+            decode_engines=self._decode_engines,
+            prefill_slice_sizes=prefill_slice_sizes,
+            decode_slice_sizes=decode_slice_sizes,
+        )
+        # for vllm compatibility
+        self.model_executor = self._prefill_engines[0].model_executor
+
+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return self._prefill_engines[0].model_executor.supported_tasks
+
+    def add_request(self, request: Request, request_wave: int = 0):
+        if not isinstance(request.request_id, str):
+            raise TypeError(
+                f"request_id must be a string, got {type(request.request_id)}")
+
+        if pooling_params := request.pooling_params:
+            supported_pooling_tasks = [
+                task for task in self.get_supported_tasks()
+                if task in POOLING_TASKS
+            ]
+
+            if pooling_params.task not in supported_pooling_tasks:
+                raise ValueError(f"Unsupported task: {pooling_params.task!r} "
+                                 f"Supported tasks: {supported_pooling_tasks}")
+
+        if request.kv_transfer_params is not None and (
+                not self.scheduler.get_kv_connector()):
+            logger.warning("Got kv_transfer_params, but no KVConnector found. "
+                           "Disabling KVTransfer for this request.")
+
+        self._orchestrator.add_request(request)
+
+    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
+        client_idx, output = self.output_queue.get()
+        # logger.warning(f"step output: {output}")
+        time.sleep(0.03)
+        return {client_idx: output}, True
+
+    def shutdown(self):
+        self._orchestrator.shutdown()
+
+    def reset_mm_cache(self):
+        # NOTE: Since this is mainly for debugging, we don't attempt to
+        # re-sync the internal caches (P0 processor, P0 mirror, P1 mirror)
+        for engine in itertools.chain(self._prefill_engines,
+                                      self._decode_engines):
+            if engine.scheduler.has_unfinished_requests():
+                logger.warning(
+                    "Resetting the multi-modal cache when requests are "
+                    "in progress may lead to desynced internal caches.")
+
+            if engine.mm_receiver_cache is not None:
+                engine.mm_receiver_cache.clear_cache()
+
+    def reset_prefix_cache(self):
+        for engine in itertools.chain(self._prefill_engines,
+                                      self._decode_engines):
+            engine.scheduler.reset_prefix_cache()
+
+
+class DisaggEngineCoreProc(vLLMEngineCoreProc):
+    """The vLLM-facing adapter that handles process management and I/O."""
+
+    @staticmethod
+    def is_supported() -> bool:
+        """
+        Returns True if this engine can run in the current environment.
+        """
+        return disagg_utils.is_disagg_enabled()
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        local_client: bool,
+        handshake_address: str,
+        executor_class: type[Executor],
+        log_stats: bool,
+        client_handshake_address: Optional[str] = None,
+        engine_index: int = 0,
+        **kwargs,
+    ):
+        if 'dp_rank' in kwargs or 'local_dp_rank' in kwargs:
+            logger.debug(
+                "Ignoring data parallelism arguments for non-DP disaggregated engine."
+            )
+        # We don't invoke the super class's ctor as we are not really the
+        # engine core to be executed; instead we create other instances of
+        # engine cores and let them do the work.
+        self.vllm_config = vllm_config
+
+        # We should be taking the input from the client; the code below is forked from
+        # vllm.v1.engine.core.EngineCoreProc.
+        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
+        self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
+                                              bytes]]()
+
+        self.engine_index = engine_index
+        identity = self.engine_index.to_bytes(length=2, byteorder="little")
+        self.engines_running = False
+
+        self.devices = jax.devices()
+        device_kind = self.devices[0].device_kind
+        if device_kind != 'TPU7x':
+            self.vllm_config.cache_config.gpu_memory_utilization = (
+                self.vllm_config.cache_config.gpu_memory_utilization - 0.1)
+        prefill_slice_sizes, decode_slice_sizes, slice_sizes = _get_slice_sizes(
+            self.devices)
+
+        if isinstance(slice_sizes[0], int):
+            setattr(vllm_config.device_config, "slice",
+                    (0, slice_sizes, self.devices))
+        else:
+            setattr(vllm_config.device_config, "slice",
+                    ((0, 0), 0, slice_sizes, self.devices))
+        logger.info(
+            f"Creating DisaggEngineCoreProc with slice_sizes {slice_sizes}...")
+
+        def executor_fail_callback():
+            self.input_queue.put_nowait(
+                (EngineCoreRequestType.EXECUTOR_FAILED, b''))
+
+        # Don't complete handshake until DP coordinator ready message is
+        # received.
+        with self._perform_handshakes(handshake_address, identity,
+                                      local_client, vllm_config,
+                                      client_handshake_address) as addresses:
+            self.client_count = len(addresses.outputs)
+
+            # Set up data parallel environment.
+            self.has_coordinator = addresses.coordinator_output is not None
+            self.frontend_stats_publish_address = (
+                addresses.frontend_stats_publish_address)
+            self.publish_dp_lb_stats = (
+                self.has_coordinator
+                and not vllm_config.parallel_config.data_parallel_external_lb)
+            # Background Threads and Queues for IO. These enable us to
+            # overlap ZMQ socket IO with GPU since they release the GIL,
+            # and to overlap some serialization/deserialization with the
+            # model forward pass.
+            # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
+
+            self._prefill_engines = _create_engine_cores(
+                prefill_slice_sizes,
+                vllm_config,
+                log_stats,
+                executor_fail_callback,
+            )
+            logger.info(
+                f"{len(self._prefill_engines)} Disaggregated prefill engines created."
+            )
+
+            self._decode_engines = _create_engine_cores(
+                decode_slice_sizes,
+                vllm_config,
+                log_stats,
+                executor_fail_callback,
+            )
+            logger.info(
+                f"{len(self._decode_engines)} Disaggregated decode engines created."
+            )
+
+            ready_event = threading.Event()
+            input_thread = threading.Thread(target=self.process_input_sockets,
+                                            args=(addresses.inputs,
+                                                  addresses.coordinator_input,
+                                                  identity, ready_event),
+                                            daemon=True)
+            input_thread.start()
+
+            self.output_thread = threading.Thread(
+                target=self.process_output_sockets,
+                args=(addresses.outputs, addresses.coordinator_output,
+                      self.engine_index),
+                daemon=True)
+            self.output_thread.start()
+            while not ready_event.wait(timeout=10):
+                if not input_thread.is_alive():
+                    raise RuntimeError(
+                        "Input socket thread died during startup")
+                if addresses.coordinator_input is not None:
+                    logger.info(
+                        "Waiting for READY message from DP Coordinator...")
+        self.request_block_hasher = None
+        if (self.vllm_config.cache_config.enable_prefix_caching
+                or self._prefill_engines[0].scheduler.get_kv_connector()
+                is not None):
+
+            block_size = vllm_config.cache_config.block_size
+            caching_hash_fn = common_utils.get_hash_fn_by_name(
+                vllm_config.cache_config.prefix_caching_hash_algo)
+            init_none_hash(caching_hash_fn)
+
+            self.request_block_hasher = get_request_block_hasher(
+                block_size, caching_hash_fn)
+
+        self.mm_receiver_cache = None
+        self._orchestrator = _DisaggOrchestrator(
+            config=vllm_config,
+            output_queue=self.output_queue,
+            prefill_engines=self._prefill_engines,
+            decode_engines=self._decode_engines,
+            prefill_slice_sizes=prefill_slice_sizes,
+            decode_slice_sizes=decode_slice_sizes,
+        )
+
+    def add_request(self, request: EngineCoreRequest, request_wave: int = 0):
+        if not isinstance(request.request_id, str):
+            raise TypeError(
+                f"request_id must be a string, got {type(request.request_id)}")
+
+        if pooling_params := request.pooling_params:
+            supported_pooling_tasks = [
+                task for task in self.get_supported_tasks()
+                if task in POOLING_TASKS
+            ]
+
+            if pooling_params.task not in supported_pooling_tasks:
+                raise ValueError(f"Unsupported task: {pooling_params.task!r} "
+                                 f"Supported tasks: {supported_pooling_tasks}")
+
+        self._orchestrator.add_request(request)
+
+    def _handle_client_request(self, request_type: EngineCoreRequestType,
+                               request: Any) -> None:
+        """Dispatch request from client."""
+        if request_type == EngineCoreRequestType.ADD:
+            req, request_wave = request
+            self.add_request(req)
+        elif request_type == EngineCoreRequestType.ABORT:
+            # TODO(fhzhang): we need to keep track of which engine is processing
+            # the request and finish it there.
+            # owner_engine.scheduler.finish_requests(request, RequestStatus.FINISHED_ABORTED)
+            pass
+        elif request_type == EngineCoreRequestType.UTILITY:
+            client_idx, call_id, method_name, args = request
+            output = UtilityOutput(call_id)
+            try:
+                method = getattr(self._prefill_engines[0], method_name)
+                result = method(*self._convert_msgspec_args(method, args))
+                output.result = UtilityResult(result)
+            except BaseException as e:
+                logger.exception("Invocation of %s method failed", method_name)
+                output.failure_message = (f"Call to {method_name} method"
+                                          f" failed: {str(e)}")
+            self.output_queue.put_nowait(
+                (client_idx, EngineCoreOutputs(utility_output=output)))
+        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
+            raise RuntimeError("Executor failed.")
+        else:
+            logger.error("Unrecognized input request type encountered: %s",
+                         request_type)
+
+    def run_busy_loop(self):
+        """Core busy loop of the EngineCore."""
+
+        # Loop until process is sent a SIGINT or SIGTERM
+        while True:
+            while not self.input_queue.empty():
+                req = self.input_queue.get_nowait()
+                self._handle_client_request(*req)
+            # Yield control to other threads, as we are not doing any real work.
+            # Without this sleep, we'd be hogging all the cpu cycles with our run_busy_loop.
+            time.sleep(0.01)
+
+    def shutdown(self):
+        self._orchestrator.shutdown()
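
For readers skimming the diff: the decode backlog sizing in `_DisaggOrchestrator.__init__` bounds each decode engine's queue by how many worst-case requests still fit in free HBM, capped at the scheduler's `max_num_seqs`. The sketch below only mirrors that arithmetic; the helper name `estimate_decode_backlog_len` and all numbers are hypothetical and are not part of the package.

# Illustrative sketch only: mirrors the backlog-sizing arithmetic in
# _DisaggOrchestrator.__init__; all names and values here are hypothetical.
def estimate_decode_backlog_len(num_kv_caches: int, max_model_len: int,
                                block_size: int, kv_block_nbytes: int,
                                hbm_free_per_device: list[int],
                                max_num_seqs: int) -> int:
    blocks_per_request = max_model_len // block_size
    # Approximate worst-case KV bytes a single request needs on one device.
    max_kv_bytes = (num_kv_caches * blocks_per_request *
                    kv_block_nbytes) // len(hbm_free_per_device)
    # Never queue more requests than the scheduler allows.
    return min(hbm_free_per_device[0] // max_kv_bytes, max_num_seqs)

# Hypothetical example: 32 KV caches, 8k context, 32-token blocks, 2 MiB per
# block, ~20 GiB free on each of 4 chips, scheduler capped at 256 sequences.
print(estimate_decode_backlog_len(32, 8192, 32, 2 * 1024 * 1024,
                                  [20 * 2**30] * 4, 256))  # -> 5

In the real constructor these inputs come from the model runner and `common_utils.hbm_usage_bytes`, and the backlog falls back to `max_num_seqs` when HBM usage cannot be read.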
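`_transfer` routes each prefilled request to the decode engine with the fewest queued requests via a manual scan; the same selection can be written as a single `min` over engine indices. A self-contained toy follows, where `_FakeEngine` is a hypothetical stand-in for the real engine object and only the request-count accessor matters.

# Illustrative sketch only: least-loaded decode engine selection, equivalent
# to the manual scan in _DisaggOrchestrator._transfer.
class _FakeEngine:
    def __init__(self, waiting: int, running: int):
        self._counts = (waiting, running)

    def get_request_counts(self) -> tuple[int, int]:
        return self._counts


def pick_least_loaded(decode_engines) -> int:
    # Choose the engine index whose waiting + running count is smallest.
    return min(range(len(decode_engines)),
               key=lambda i: sum(decode_engines[i].get_request_counts()))


engines = [_FakeEngine(3, 5), _FakeEngine(0, 2), _FakeEngine(4, 4)]
assert pick_least_loaded(engines) == 1  # engine 1 has the fewest requests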