tpu-inference 0.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of tpu-inference might be problematic.

Files changed (168)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/collectives/__init__.py +0 -0
  53. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  54. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  55. tpu_inference/kernels/collectives/util.py +47 -0
  56. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  57. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  58. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  59. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  60. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  61. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  62. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  66. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  71. tpu_inference/layers/__init__.py +0 -0
  72. tpu_inference/layers/common/__init__.py +0 -0
  73. tpu_inference/layers/common/attention_metadata.py +34 -0
  74. tpu_inference/layers/jax/__init__.py +0 -0
  75. tpu_inference/layers/jax/attention/__init__.py +0 -0
  76. tpu_inference/layers/jax/attention/attention.py +254 -0
  77. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  78. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  79. tpu_inference/layers/jax/attention_interface.py +356 -0
  80. tpu_inference/layers/jax/base.py +151 -0
  81. tpu_inference/layers/jax/binary_search.py +295 -0
  82. tpu_inference/layers/jax/constants.py +88 -0
  83. tpu_inference/layers/jax/layers.py +301 -0
  84. tpu_inference/layers/jax/misc.py +16 -0
  85. tpu_inference/layers/jax/moe/__init__.py +0 -0
  86. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  87. tpu_inference/layers/jax/moe/moe.py +209 -0
  88. tpu_inference/layers/jax/rope.py +172 -0
  89. tpu_inference/layers/jax/rope_interface.py +214 -0
  90. tpu_inference/layers/jax/sample/__init__.py +0 -0
  91. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  92. tpu_inference/layers/jax/sample/sampling.py +95 -0
  93. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  94. tpu_inference/layers/jax/sharding.py +406 -0
  95. tpu_inference/layers/jax/transformer_block.py +76 -0
  96. tpu_inference/layers/vllm/__init__.py +0 -0
  97. tpu_inference/layers/vllm/attention.py +184 -0
  98. tpu_inference/layers/vllm/fused_moe.py +399 -0
  99. tpu_inference/layers/vllm/linear_common.py +186 -0
  100. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  101. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  102. tpu_inference/layers/vllm/quantization/common.py +105 -0
  103. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  104. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  105. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  106. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  108. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  109. tpu_inference/layers/vllm/sharding.py +151 -0
  110. tpu_inference/logger.py +10 -0
  111. tpu_inference/lora/__init__.py +0 -0
  112. tpu_inference/lora/torch_lora_ops.py +103 -0
  113. tpu_inference/lora/torch_punica_tpu.py +308 -0
  114. tpu_inference/mock/__init__.py +0 -0
  115. tpu_inference/mock/vllm_config_utils.py +28 -0
  116. tpu_inference/mock/vllm_envs.py +1233 -0
  117. tpu_inference/mock/vllm_logger.py +212 -0
  118. tpu_inference/mock/vllm_logging_utils.py +15 -0
  119. tpu_inference/models/__init__.py +0 -0
  120. tpu_inference/models/common/__init__.py +0 -0
  121. tpu_inference/models/common/model_loader.py +433 -0
  122. tpu_inference/models/jax/__init__.py +0 -0
  123. tpu_inference/models/jax/deepseek_v3.py +868 -0
  124. tpu_inference/models/jax/llama3.py +366 -0
  125. tpu_inference/models/jax/llama4.py +473 -0
  126. tpu_inference/models/jax/llama_eagle3.py +333 -0
  127. tpu_inference/models/jax/phi3.py +376 -0
  128. tpu_inference/models/jax/qwen2.py +375 -0
  129. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  130. tpu_inference/models/jax/qwen3.py +302 -0
  131. tpu_inference/models/jax/utils/__init__.py +0 -0
  132. tpu_inference/models/jax/utils/file_utils.py +96 -0
  133. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  134. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  136. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  137. tpu_inference/models/vllm/__init__.py +0 -0
  138. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  139. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  140. tpu_inference/platforms/__init__.py +2 -0
  141. tpu_inference/platforms/tpu_jax.py +257 -0
  142. tpu_inference/runner/__init__.py +0 -0
  143. tpu_inference/runner/block_table_jax.py +122 -0
  144. tpu_inference/runner/compilation_manager.py +672 -0
  145. tpu_inference/runner/input_batch_jax.py +435 -0
  146. tpu_inference/runner/kv_cache.py +119 -0
  147. tpu_inference/runner/kv_cache_manager.py +460 -0
  148. tpu_inference/runner/lora_utils.py +92 -0
  149. tpu_inference/runner/multimodal_manager.py +208 -0
  150. tpu_inference/runner/persistent_batch_manager.py +244 -0
  151. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  152. tpu_inference/runner/structured_decoding_manager.py +89 -0
  153. tpu_inference/runner/tpu_jax_runner.py +771 -0
  154. tpu_inference/runner/utils.py +426 -0
  155. tpu_inference/spec_decode/__init__.py +0 -0
  156. tpu_inference/spec_decode/jax/__init__.py +0 -0
  157. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  158. tpu_inference/tpu_info.py +77 -0
  159. tpu_inference/utils.py +294 -0
  160. tpu_inference/worker/__init__.py +0 -0
  161. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  162. tpu_inference/worker/base.py +100 -0
  163. tpu_inference/worker/tpu_worker_jax.py +321 -0
  164. tpu_inference-0.11.1.dist-info/METADATA +101 -0
  165. tpu_inference-0.11.1.dist-info/RECORD +168 -0
  166. tpu_inference-0.11.1.dist-info/WHEEL +5 -0
  167. tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
  168. tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
@@ -0,0 +1,776 @@
+ # SPDX-License-Identifier: Apache-2.0
+ import functools
+ import itertools
+ import math
+ import os
+ import queue
+ import signal
+ import threading
+ import time
+ import traceback
+ from typing import Any, Callable, Optional, Tuple, TypeVar, Union
+
+ import jax
+ # ======================================================================================
+ # Imports for DisaggEngineCoreProc (the vLLM adapter)
+ # ======================================================================================
+ from vllm.config import VllmConfig
+ from vllm.logger import init_logger
+ from vllm.tasks import POOLING_TASKS, SupportedTask
+ from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
+                                          init_none_hash)
+ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
+                             EngineCoreRequestType, UtilityOutput,
+                             UtilityResult)
+ from vllm.v1.engine.core import EngineCore as vLLMEngineCore
+ from vllm.v1.engine.core import EngineCoreProc as vLLMEngineCoreProc
+ from vllm.v1.executor.abstract import Executor
+ from vllm.v1.request import Request, RequestStatus
+
+ from tpu_inference import utils as common_utils
+ from tpu_inference.core import disagg_executor, disagg_utils
+ # ======================================================================================
+ # Imports for _DisaggOrchestrator (decoupled from vLLM)
+ # ======================================================================================
+ from tpu_inference.interfaces.config import IConfig
+ from tpu_inference.interfaces.engine import IEngineCore
+ from tpu_inference.interfaces.request import IRequest
+ from tpu_inference.runner.utils import LatencyTracker
+
+ from .adapters import VllmConfigAdapter, VllmEngineAdapter, VllmRequestAdapter
+
+ # This file contains two classes:
+ # 1. _DisaggOrchestrator: The clean, decoupled core orchestration logic.
+ # 2. DisaggEngineCoreProc: The vLLM-facing adapter that handles process management.
+
+ logger = init_logger(__name__)
+
+ POLLING_TIMEOUT_S = 2.5
+ HANDSHAKE_TIMEOUT_MINS = 5
+
+ _R = TypeVar('_R')  # Return type for collective_rpc
+
+ # ======================================================================================
+ # Class 1: The Decoupled Orchestrator
+ # ======================================================================================
+
+
+ class JetThread(threading.Thread):
+     """Thread that kills the program if it fails.
+
+     If a driver thread goes down, we can't operate.
+     """
+
+     def run(self):
+         try:
+             super().run()
+         except Exception as e:  # pylint: disable=broad-exception-caught
+             print(f"Thread {self.name} encountered an error: {e}")
+             traceback.print_exc()
+             os.kill(os.getpid(), signal.SIGKILL)
+
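JetThread above escalates any uncaught exception in a pipeline thread to a process kill. A minimal usage sketch (the failing target is hypothetical, not from this file): an exception raised in the thread body is printed with its traceback and the whole process is terminated via SIGKILL, rather than the stage dying silently.

def flaky_stage():
    # Stand-in for a prefill/transfer/decode driver loop that hits a fatal error.
    raise RuntimeError("driver thread died")

t = JetThread(target=flaky_stage, name="demo-stage", daemon=True)
t.start()
t.join()  # run() catches the exception, prints it, then os.kill(..., SIGKILL)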
+
+ class _DisaggOrchestrator:
+     """Contains the core orchestration logic, decoupled from vLLM."""
+
+     def __init__(
+         self,
+         config: IConfig,
+         output_queue: queue.Queue,
+         prefill_engines: list[IEngineCore],
+         decode_engines: list[IEngineCore],
+         prefill_slice_sizes: tuple[int, ...],
+         decode_slice_sizes: tuple[int, ...],
+     ):
+         self._config = config
+         self._output_queue = output_queue
+         self._prefill_engines = prefill_engines
+         self._decode_engines = decode_engines
+
+         # Keep track of active requests.
+         self._requests: dict[str, IRequest] = {}
+
+         # Hack device config to pass in the subslice of TPUs.
+         slice_sizes = list(prefill_slice_sizes)
+         slice_sizes.extend(decode_slice_sizes)
+
+         self._transfer_backlogs = [
+             queue.Queue(4) for i in range(len(self._prefill_engines))
+         ]
+
+         self._decode_backlogs = {}
+         for idx, engine in enumerate(self._decode_engines):
+             # Determine the decode backlog length: remaining HBM divided by the max KV cache size of a single request.
+             runner = engine.model_executor.driver_worker.model_runner
+             hbm_usage = common_utils.hbm_usage_bytes(
+                 engine.model_executor.driver_worker.devices)
+             if not hbm_usage:
+                 self._decode_backlogs[idx] = queue.Queue(
+                     self._config.scheduler_config.max_num_seqs)
+                 continue
+             hbm_free = [limit - used for used, limit in hbm_usage]
+             max_kv_bytes = len(runner.kv_caches) * (
+                 runner.max_model_len // runner.cache_config.block_size) * (
+                     runner.kv_caches[0][0].nbytes) // len(hbm_free)
+             max_queue_len = min(hbm_free[0] // max_kv_bytes,
+                                 self._config.scheduler_config.max_num_seqs)
+             logger.debug(
+                 f"max kv bytes: {max_kv_bytes}, max_queue_len {max_queue_len}")
+             self._decode_backlogs[idx] = queue.Queue(max_queue_len)
+
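A rough worked example of the backlog sizing above, with hypothetical numbers (not measurements from this package): 32 per-layer KV caches, a 4096-token max_model_len, 16-token blocks, 2 MiB per KV block and 4 devices give about 4 GiB of per-device HBM per request, so 32 GiB of free HBM yields a backlog of 8.

num_kv_caches = 32                 # len(runner.kv_caches), one per layer
blocks_per_request = 4096 // 16    # max_model_len // block_size
block_nbytes = 2 * 1024**2         # runner.kv_caches[0][0].nbytes
num_devices = 4                    # len(hbm_free)
hbm_free_per_device = 32 * 1024**3

max_kv_bytes = num_kv_caches * blocks_per_request * block_nbytes // num_devices
max_queue_len = min(hbm_free_per_device // max_kv_bytes, 256)  # 256 ~ max_num_seqs
assert max_kv_bytes == 4 * 1024**3 and max_queue_len == 8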
+         self._prefill_threads = [
+             JetThread(
+                 target=functools.partial(self._prefill, idx),
+                 name=f"prefill-{idx}",
+                 daemon=True,
+             ) for idx in range(len(self._prefill_engines))
+         ]
+         self._transfer_threads = [
+             JetThread(
+                 target=functools.partial(
+                     self._transfer,
+                     idx,
+                 ),
+                 name=f"transfer-{idx}",
+                 daemon=True,
+             ) for idx in range(len(self._prefill_engines))
+         ]
+         self._decode_threads = [
+             JetThread(
+                 target=functools.partial(
+                     self._decode,
+                     idx,
+                 ),
+                 name=f"decode-{idx}",
+                 daemon=True,
+             ) for idx in range(len(self._decode_engines))
+         ]
+         self._all_threads = list(
+             itertools.chain(
+                 self._prefill_threads,
+                 self._transfer_threads,
+                 self._decode_threads,
+             ))
+         self.live = True
+         # Start all threads
+         for t in self._all_threads:
+             t.start()
+
+     def add_request(self, request: IRequest):
+         """
+         Adds a new request to the orchestrator.
+
+         This is the main entry point for new work. It stores the request for
+         internal state tracking and hands it off to the first stage of the
+         processing pipeline (the prefill scheduler).
+         """
+         # Hand off the request to the prefill scheduler to be batched for
+         # execution. Note: We are calling add_request on our IScheduler
+         # interface. The VllmSchedulerAdapter will handle unwrapping the
+         # IRequest before passing it to the real vllm scheduler.
+         self._prefill_engines[0].scheduler.add_request(request)
+
+         # Add to internal state for tracking by other threads.
+         # The key is the request_id, the value is the adapted request object.
+         self._requests[request.vllm_request.request_id] = request
+
+     def _prefill(self, idx: int):
+         prefill_engine = self._prefill_engines[idx]
+         transfer_backlog = self._transfer_backlogs[idx]
+
+         while self.live:
+             if not prefill_engine.scheduler.has_requests():
+                 time.sleep(0.05)
+                 continue
+
+             scheduler_output = prefill_engine.scheduler.schedule()
+             with LatencyTracker(f"prefill-{idx}"):
+                 model_output = prefill_engine.execute_model_with_error_logging(
+                     prefill_engine.model_executor.execute_model,
+                     scheduler_output)
+
+             if scheduler_output.total_num_scheduled_tokens > 0:
+                 logger.debug(f"Prefill result: {model_output}")
+
+             kv_cache_map: dict[str, Tuple[list[jax.Array], list[Any]]] = {}
+             for req_id, idx in model_output.req_id_to_index.items():
+                 if len(model_output.sampled_token_ids[idx]) > 0:
+                     request = self._requests[req_id]
+                     block_ids = (prefill_engine.scheduler.kv_cache_manager.
+                                  get_block_ids(req_id))
+                     # Assume one KV cache group for now.
+                     kv_cache_map[req_id] = (
+                         prefill_engine.model_executor.driver_worker.
+                         model_runner.get_kv_cache_for_block_ids(
+                             block_ids[0]), request.block_hashes)
+                     logger.debug(f"prefill done: for {req_id}")
+             transfer_backlog.put(kv_cache_map, block=True)
+
+             # Tweak model_output to let the scheduler know kv_transfer is done for requests, so they can be freed.
+             engine_core_outputs = prefill_engine.scheduler.update_from_output(
+                 scheduler_output, model_output)  # type: ignore
+
+             for req_id, idx in model_output.req_id_to_index.items():
+                 if len(model_output.sampled_token_ids[idx]) > 0:
+                     request = self._requests[req_id].vllm_request
+                     logger.debug(
+                         f"request block_hashes at prefill: {request.block_hashes}"
+                     )
+                     logger.debug(
+                         f"request-{req_id}: tokens={request.all_token_ids} after prefill"
+                     )
+                     # Remove request from the prefill engine.
+
+                     request = prefill_engine.scheduler.requests[req_id]
+                     prefill_engine.scheduler.running.remove(request)
+                     prefill_engine.scheduler.encoder_cache_manager.free(
+                         request)
+
+                     prefill_engine.scheduler.kv_cache_manager.free(request)
+
+                     prefill_engine.scheduler.requests.pop(req_id)
+
+             for output in (engine_core_outputs.items()
+                            if engine_core_outputs else ()):
+                 self._output_queue.put_nowait(output)
+
+     def _transfer(self, idx: int):
+         """Transfers the kv cache on an active request to the least full
+         decode backlog."""
+         transfer_backlog = self._transfer_backlogs[idx]
+         while self.live:
+             # The transfer thread can just sleep until it has work to do.
+             kv_cache_map = transfer_backlog.get(block=True)
+             if kv_cache_map is None:
+                 break
+
+             logger.debug(
+                 f"transfer-{idx}: KV Cache items received: {kv_cache_map.keys()}"
+             )
+
+             push_targets = []
+             for req_id, (kv_cache, block_hashes) in kv_cache_map.items():
+                 target_idx = -1
+                 cnt = 9999999
+                 for i, e in enumerate(self._decode_engines):
+                     req_cnt = sum(e.scheduler.get_request_counts())
+                     if req_cnt < cnt:
+                         cnt = req_cnt
+                         target_idx = i
+
+                 # Only transfer the KVCache for the disaggregated serving.
+                 with LatencyTracker("KVCacheTransfer"):
+                     kv_cache = self._decode_engines[
+                         target_idx].model_executor.driver_worker.model_runner.transfer_kv_cache(
+                             kv_cache)
+
+                 # TODO(fhzhang): Now how do we get the kv cache to the decode engine?
+                 prefill_output = {
+                     "cache": kv_cache,
+                     "req_id": req_id,
+                     "block_hashes": block_hashes,
+                 }
+                 push_targets.append((target_idx, prefill_output))
+
+             for target_idx, prefill_output in push_targets:
+                 self._decode_backlogs[target_idx].put(prefill_output,
+                                                       block=True)
+                 logger.debug(
+                     "Successfully transferred prefill request %s "
+                     "from prefill engine %d to decode engine %d. decode backlog len %d",
+                     prefill_output["req_id"],
+                     idx,
+                     target_idx,
+                     self._decode_backlogs[target_idx].qsize(),
+                 )
+
+     def _decode(self, idx: int):
+         decode_engine = self._decode_engines[idx]
+         decode_backlog = self._decode_backlogs[idx]
+
+         while self.live:
+             block = not decode_engine.scheduler.has_requests()
+
+             while True:
+                 # We need to check the input batch as well, since request completion is delayed
+                 # from the scheduler to the runner.
+                 if (sum(decode_engine.scheduler.get_request_counts())
+                         >= self._config.scheduler_config.max_num_seqs
+                         or decode_engine.model_executor.driver_worker.
+                         model_runner.input_batch.num_reqs
+                         >= self._config.scheduler_config.max_num_seqs):
+                     break
+
+                 try:
+                     prefill_output = decode_backlog.get(block=block,
+                                                         timeout=1.0)
+                 except queue.Empty:
+                     if block:
+                         continue
+                     break
+
+                 if prefill_output is None:
+                     logger.debug(
+                         f"decode-{idx} Empty output, and we are idle, exiting..."
+                     )
+                     break
+
+                 # We got a request, set block to False
+                 block = False
+
+                 # Insert the request to the decoder.
+                 req_id = prefill_output["req_id"]
+                 vllm_request = self._requests[req_id].vllm_request
+                 # Cache num_computed_tokens: the KV cache manager allocates blocks
+                 # for num_computed_tokens + num_new_tokens, so without zeroing it
+                 # first the token count would be doubled.
+                 prompt_tokens = vllm_request.num_computed_tokens
+                 vllm_request.num_computed_tokens = 0
+                 kv_cache = prefill_output["cache"]
+
+                 kv_cache_manager = decode_engine.scheduler.kv_cache_manager
+                 kv_cache_manager.allocate_slots(
+                     vllm_request,
+                     prompt_tokens,
+                 )
+                 vllm_request.num_computed_tokens = prompt_tokens
+                 new_block_ids = kv_cache_manager.get_block_ids(req_id)
+                 logger.debug(
+                     f"inserting {req_id} new_block_ids {new_block_ids}")
+                 assert (len(new_block_ids[0]) == math.ceil(
+                     prompt_tokens / self._config.cache_config.block_size))
+
+                 decode_engine.model_executor.driver_worker.model_runner.insert_request_with_kv_cache(
+                     vllm_request, kv_cache, new_block_ids)
+
+                 vllm_request.status = RequestStatus.RUNNING
+                 block_hashes = prefill_output["block_hashes"]
+                 vllm_request.block_hashes = block_hashes
+                 decode_engine.scheduler.running.append(vllm_request)
+                 decode_engine.scheduler.requests[req_id] = vllm_request
+
+                 self._requests.pop(req_id)
+
+             scheduler_output = decode_engine.scheduler.schedule()
+
+             logger.debug(f'''decode-{idx}: scheduler_output -
+                 {scheduler_output.scheduled_cached_reqs.num_computed_tokens},
+                 new block ids - {scheduler_output.scheduled_cached_reqs.new_block_ids}'''
+                          )
+
+             with LatencyTracker(f"decode-{idx}"):
+                 model_output = decode_engine.execute_model_with_error_logging(
+                     decode_engine.model_executor.execute_model,
+                     scheduler_output)
+             if scheduler_output.total_num_scheduled_tokens > 0:
+                 logger.debug(f"Decode result: {model_output}")
+
+             engine_core_outputs = decode_engine.scheduler.update_from_output(
+                 scheduler_output, model_output)  # type: ignore
+             for output in (engine_core_outputs.items()
+                            if engine_core_outputs else ()):
+                 self._output_queue.put_nowait(output)
+
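A quick check of the block-count assertion in _decode above, with hypothetical numbers: a prefill of 1000 tokens under a 16-token block size must arrive with ceil(1000 / 16) = 63 allocated KV blocks, after which num_computed_tokens is restored to the prompt length.

import math

prompt_tokens, block_size = 1000, 16   # hypothetical values
num_blocks = math.ceil(prompt_tokens / block_size)
assert num_blocks == 63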
+     def shutdown(self):
+         for e in self._prefill_engines:
+             e.shutdown()
+         for e in self._decode_engines:
+             e.shutdown()
+
+
+ # ======================================================================================
+ # Class 2: The vLLM-Facing Adapter
+ # ======================================================================================
+
+
+ def _create_engine_cores(
+     slice_sizes: tuple[int, ...],
+     vllm_config: VllmConfig,
+     log_stats: bool,
+     executor_fail_callback: Optional[Callable] = None,
+ ) -> list[vLLMEngineCore]:
+     engine_cores = []
+     for _ in slice_sizes:
+         engine_core = vLLMEngineCore(
+             vllm_config,
+             disagg_executor.DisaggExecutor,
+             log_stats,
+             executor_fail_callback,
+         )
+
+         engine_cores.append(engine_core)
+         logger.warning("Disaggregated engine core created.")
+
+     return engine_cores
+
+
+ def _get_slice_sizes(devices):
+     prefill_slice_sizes = disagg_utils.get_prefill_slices()
+     decode_slice_sizes = disagg_utils.get_decode_slices()
+     if isinstance(prefill_slice_sizes[0], int):
+         prefill_chip_cnt = sum(prefill_slice_sizes)
+     else:
+         prefill_chip_cnt = sum([math.prod(t) for t in prefill_slice_sizes])
+     if isinstance(decode_slice_sizes[0], int):
+         decode_chip_cnt = sum(decode_slice_sizes)
+     else:
+         decode_chip_cnt = sum([math.prod(t) for t in decode_slice_sizes])
+     assert decode_chip_cnt + prefill_chip_cnt <= len(devices)
+     assert prefill_chip_cnt > 0 and decode_chip_cnt > 0
+
+     slice_sizes = list(prefill_slice_sizes)
+     slice_sizes.extend(decode_slice_sizes)
+     return prefill_slice_sizes, decode_slice_sizes, slice_sizes
+
+
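_get_slice_sizes above accepts either flat chip counts or per-slice mesh shapes; a small sketch of the same chip-count arithmetic with made-up values (not defaults of this package):

import math

prefill_slice_sizes = (4,)        # flat chip counts
decode_slice_sizes = ((2, 2),)    # mesh shapes; math.prod gives chips per slice

prefill_chip_cnt = (sum(prefill_slice_sizes)
                    if isinstance(prefill_slice_sizes[0], int) else
                    sum(math.prod(t) for t in prefill_slice_sizes))
decode_chip_cnt = (sum(decode_slice_sizes)
                   if isinstance(decode_slice_sizes[0], int) else
                   sum(math.prod(t) for t in decode_slice_sizes))
assert prefill_chip_cnt == 4 and decode_chip_cnt == 4  # fits an 8-chip host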
+ class DisaggEngineCore(vLLMEngineCore):
+     """The vLLM-facing adapter that handles process management and I/O. Modifies vLLMEngineCore and is only used by the in-process EngineCore client."""
+
+     @staticmethod
+     def is_supported() -> bool:
+         """
+         Returns True if this engine can run in the current environment.
+         """
+         return disagg_utils.is_disagg_enabled()
+
+     def __init__(
+         self,
+         vllm_config: VllmConfig,
+         executor_class: type[Executor],
+         log_stats: bool,
+         executor_fail_callback: Optional[Callable] = None,
+     ):
+         self.vllm_config = vllm_config
+         self.vllm_config.cache_config.gpu_memory_utilization = (
+             self.vllm_config.cache_config.gpu_memory_utilization - 0.1)
+
+         self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
+                                               bytes]]()
+
+         self.devices = jax.devices()
+         prefill_slice_sizes, decode_slice_sizes, slice_sizes = _get_slice_sizes(
+             self.devices)
+
+         if isinstance(slice_sizes[0], int):
+             setattr(vllm_config.device_config, "slice",
+                     (0, slice_sizes, self.devices))
+         else:
+             setattr(vllm_config.device_config, "slice",
+                     ((0, 0), 0, slice_sizes, self.devices))
+         logger.info(
+             f"Creating DisaggEngineCore with slice_sizes {slice_sizes}...")
+
+         self._prefill_engines = _create_engine_cores(
+             prefill_slice_sizes,
+             vllm_config,
+             log_stats,
+             executor_fail_callback,
+         )
+         logger.info(
+             f"{len(self._prefill_engines)} Disaggregated prefill engines created."
+         )
+
+         self._decode_engines = _create_engine_cores(
+             decode_slice_sizes,
+             vllm_config,
+             log_stats,
+             executor_fail_callback,
+         )
+         logger.info(
+             f"{len(self._decode_engines)} Disaggregated decode engines created."
+         )
+
+         self.batch_queue = None
+
+         self.request_block_hasher = None
+         if (self.vllm_config.cache_config.enable_prefix_caching
+                 or self._prefill_engines[0].scheduler.get_kv_connector()
+                 is not None):
+
+             block_size = vllm_config.cache_config.block_size
+             caching_hash_fn = common_utils.get_hash_fn_by_name(
+                 vllm_config.cache_config.prefix_caching_hash_algo)
+             init_none_hash(caching_hash_fn)
+
+             self.request_block_hasher = get_request_block_hasher(
+                 block_size, caching_hash_fn)
+
+         self.step_fn = (self.step if self.batch_queue is None else
+                         self.step_with_batch_queue)
+
+         self.mm_receiver_cache = None
+         self._orchestrator = _DisaggOrchestrator(
+             config=VllmConfigAdapter(vllm_config),
+             output_queue=self.output_queue,
+             prefill_engines=[
+                 VllmEngineAdapter(e) for e in self._prefill_engines
+             ],
+             decode_engines=[
+                 VllmEngineAdapter(e) for e in self._decode_engines
+             ],
+             prefill_slice_sizes=prefill_slice_sizes,
+             decode_slice_sizes=decode_slice_sizes,
+         )
+         # for vllm compatibility
+         self.model_executor = self._prefill_engines[0].model_executor
+
+     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+         return self._prefill_engines[0].model_executor.supported_tasks
+
+     def add_request(self, request: Request, request_wave: int = 0):
+         if not isinstance(request.request_id, str):
+             raise TypeError(
+                 f"request_id must be a string, got {type(request.request_id)}")
+
+         if pooling_params := request.pooling_params:
+             supported_pooling_tasks = [
+                 task for task in self.get_supported_tasks()
+                 if task in POOLING_TASKS
+             ]
+
+             if pooling_params.task not in supported_pooling_tasks:
+                 raise ValueError(f"Unsupported task: {pooling_params.task!r} "
+                                  f"Supported tasks: {supported_pooling_tasks}")
+
+         if request.kv_transfer_params is not None and (
+                 not self.scheduler.get_kv_connector()):
+             logger.warning("Got kv_transfer_params, but no KVConnector found. "
+                            "Disabling KVTransfer for this request.")
+
+         self._orchestrator.add_request(VllmRequestAdapter(request))
+
+     def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
+         client_idx, output = self.output_queue.get()
+         # logger.warning(f"step output: {output}")
+         time.sleep(0.03)
+         return {client_idx: output}, True
+
+     def shutdown(self):
+         self._orchestrator.shutdown()
+
+     def reset_mm_cache(self):
+         # NOTE: Since this is mainly for debugging, we don't attempt to
+         # re-sync the internal caches (P0 processor, P0 mirror, P1 mirror)
+         for engine in itertools.chain(self._prefill_engines,
+                                       self._decode_engines):
+             if engine.scheduler.has_unfinished_requests():
+                 logger.warning(
+                     "Resetting the multi-modal cache when requests are "
+                     "in progress may lead to desynced internal caches.")
+
+             if engine.mm_receiver_cache is not None:
+                 engine.mm_receiver_cache.clear_cache()
+
+     def reset_prefix_cache(self):
+         for engine in itertools.chain(self._prefill_engines,
+                                       self._decode_engines):
+             engine.scheduler.reset_prefix_cache()
+
+
+ class DisaggEngineCoreProc(vLLMEngineCoreProc):
+     """The vLLM-facing adapter that handles process management and I/O."""
+
+     @staticmethod
+     def is_supported() -> bool:
+         """
+         Returns True if this engine can run in the current environment.
+         """
+         return disagg_utils.is_disagg_enabled()
+
+     def __init__(
+         self,
+         vllm_config: VllmConfig,
+         local_client: bool,
+         handshake_address: str,
+         executor_class: type[Executor],
+         log_stats: bool,
+         client_handshake_address: Optional[str] = None,
+         engine_index: int = 0,
+         **kwargs,
+     ):
+         if 'dp_rank' in kwargs or 'local_dp_rank' in kwargs:
+             logger.debug(
+                 "Ignoring data parallelism arguments for non-DP disaggregated engine."
+             )
+         # We don't invoke the super class's ctor as we are not really the
+         # engine core to be executed; instead we create other instances of
+         # engine cores and let them do the work.
+         self.vllm_config = vllm_config
+         self.vllm_config.cache_config.gpu_memory_utilization = self.vllm_config.cache_config.gpu_memory_utilization - 0.1
+
+         # We should be taking the input from the client; the code below is forked from
+         # vllm.v1.engine.core.EngineCoreProc.
+         self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
+         self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
+                                               bytes]]()
+
+         self.engine_index = engine_index
+         identity = self.engine_index.to_bytes(length=2, byteorder="little")
+         self.engines_running = False
+
+         self.devices = jax.devices()
+         prefill_slice_sizes, decode_slice_sizes, slice_sizes = _get_slice_sizes(
+             self.devices)
+
+         if isinstance(slice_sizes[0], int):
+             setattr(vllm_config.device_config, "slice",
+                     (0, slice_sizes, self.devices))
+         else:
+             setattr(vllm_config.device_config, "slice",
+                     ((0, 0), 0, slice_sizes, self.devices))
+         logger.info(
+             f"Creating DisaggEngineCoreProc with slice_sizes {slice_sizes}...")
+
+         def executor_fail_callback():
+             self.input_queue.put_nowait(
+                 (EngineCoreRequestType.EXECUTOR_FAILED, b''))
+
+         # Don't complete handshake until DP coordinator ready message is
+         # received.
+         with self._perform_handshakes(handshake_address, identity,
+                                       local_client, vllm_config,
+                                       client_handshake_address) as addresses:
+             self.client_count = len(addresses.outputs)
+
+             # Set up data parallel environment.
+             self.has_coordinator = addresses.coordinator_output is not None
+             self.frontend_stats_publish_address = (
+                 addresses.frontend_stats_publish_address)
+             self.publish_dp_lb_stats = (
+                 self.has_coordinator
+                 and not vllm_config.parallel_config.data_parallel_external_lb)
+             # Background Threads and Queues for IO. These enable us to
+             # overlap ZMQ socket IO with GPU since they release the GIL,
+             # and to overlap some serialization/deserialization with the
+             # model forward pass.
+             # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
+
+             self._prefill_engines = _create_engine_cores(
+                 prefill_slice_sizes,
+                 vllm_config,
+                 log_stats,
+                 executor_fail_callback,
+             )
+             logger.info(
+                 f"{len(self._prefill_engines)} Disaggregated prefill engines created."
+             )
+
+             self._decode_engines = _create_engine_cores(
+                 decode_slice_sizes,
+                 vllm_config,
+                 log_stats,
+                 executor_fail_callback,
+             )
+             logger.info(
+                 f"{len(self._decode_engines)} Disaggregated decode engines created."
+             )
+
+             ready_event = threading.Event()
+             input_thread = threading.Thread(target=self.process_input_sockets,
+                                             args=(addresses.inputs,
+                                                   addresses.coordinator_input,
+                                                   identity, ready_event),
+                                             daemon=True)
+             input_thread.start()
+
+             self.output_thread = threading.Thread(
+                 target=self.process_output_sockets,
+                 args=(addresses.outputs, addresses.coordinator_output,
+                       self.engine_index),
+                 daemon=True)
+             self.output_thread.start()
+             while not ready_event.wait(timeout=10):
+                 if not input_thread.is_alive():
+                     raise RuntimeError(
+                         "Input socket thread died during startup")
+                 if addresses.coordinator_input is not None:
+                     logger.info(
+                         "Waiting for READY message from DP Coordinator...")
+         self.request_block_hasher = None
+         if (self.vllm_config.cache_config.enable_prefix_caching
+                 or self._prefill_engines[0].scheduler.get_kv_connector()
+                 is not None):
+
+             block_size = vllm_config.cache_config.block_size
+             caching_hash_fn = common_utils.get_hash_fn_by_name(
+                 vllm_config.cache_config.prefix_caching_hash_algo)
+             init_none_hash(caching_hash_fn)
+
+             self.request_block_hasher = get_request_block_hasher(
+                 block_size, caching_hash_fn)
+
+         self.mm_receiver_cache = None
+         self._orchestrator = _DisaggOrchestrator(
+             config=VllmConfigAdapter(vllm_config),
+             output_queue=self.output_queue,
+             prefill_engines=[
+                 VllmEngineAdapter(e) for e in self._prefill_engines
+             ],
+             decode_engines=[
+                 VllmEngineAdapter(e) for e in self._decode_engines
+             ],
+             prefill_slice_sizes=prefill_slice_sizes,
+             decode_slice_sizes=decode_slice_sizes,
+         )
+
+     def add_request(self, request: EngineCoreRequest, request_wave: int = 0):
+         if not isinstance(request.request_id, str):
+             raise TypeError(
+                 f"request_id must be a string, got {type(request.request_id)}")
+
+         if pooling_params := request.pooling_params:
+             supported_pooling_tasks = [
+                 task for task in self.get_supported_tasks()
+                 if task in POOLING_TASKS
+             ]
+
+             if pooling_params.task not in supported_pooling_tasks:
+                 raise ValueError(f"Unsupported task: {pooling_params.task!r} "
+                                  f"Supported tasks: {supported_pooling_tasks}")
+
+         self._orchestrator.add_request(VllmRequestAdapter(request))
+
+     def _handle_client_request(self, request_type: EngineCoreRequestType,
+                                request: Any) -> None:
+         """Dispatch request from client."""
+         if request_type == EngineCoreRequestType.ADD:
+             req, request_wave = request
+             self.add_request(req)
+         elif request_type == EngineCoreRequestType.ABORT:
+             # TODO(fhzhang): we need to keep track of which engine is processing
+             # the request and finish it there.
+             # owner_engine.scheduler.finish_requests(request, RequestStatus.FINISHED_ABORTED)
+             pass
+         elif request_type == EngineCoreRequestType.UTILITY:
+             client_idx, call_id, method_name, args = request
+             output = UtilityOutput(call_id)
+             try:
+                 method = getattr(self._prefill_engines[0], method_name)
+                 result = method(*self._convert_msgspec_args(method, args))
+                 output.result = UtilityResult(result)
+             except BaseException as e:
+                 logger.exception("Invocation of %s method failed", method_name)
+                 output.failure_message = (f"Call to {method_name} method"
+                                           f" failed: {str(e)}")
+             self.output_queue.put_nowait(
+                 (client_idx, EngineCoreOutputs(utility_output=output)))
+         elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
+             raise RuntimeError("Executor failed.")
+         else:
+             logger.error("Unrecognized input request type encountered: %s",
+                          request_type)
+
+     def run_busy_loop(self):
+         """Core busy loop of the EngineCore."""
+
+         # Loop until process is sent a SIGINT or SIGTERM
+         while True:
+             while not self.input_queue.empty():
+                 req = self.input_queue.get_nowait()
+                 self._handle_client_request(*req)
+             # Yield control to other threads, as we are not doing any real work.
+             # Without this sleep, we'd be hogging all the cpu cycles with our run_busy_loop.
+             time.sleep(0.01)
+
+     def shutdown(self):
+         self._orchestrator.shutdown()