tpu-inference 0.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (168) hide show
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/collectives/__init__.py +0 -0
  53. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  54. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  55. tpu_inference/kernels/collectives/util.py +47 -0
  56. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  57. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  58. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  59. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  60. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  61. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  62. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  66. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  71. tpu_inference/layers/__init__.py +0 -0
  72. tpu_inference/layers/common/__init__.py +0 -0
  73. tpu_inference/layers/common/attention_metadata.py +34 -0
  74. tpu_inference/layers/jax/__init__.py +0 -0
  75. tpu_inference/layers/jax/attention/__init__.py +0 -0
  76. tpu_inference/layers/jax/attention/attention.py +254 -0
  77. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  78. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  79. tpu_inference/layers/jax/attention_interface.py +356 -0
  80. tpu_inference/layers/jax/base.py +151 -0
  81. tpu_inference/layers/jax/binary_search.py +295 -0
  82. tpu_inference/layers/jax/constants.py +88 -0
  83. tpu_inference/layers/jax/layers.py +301 -0
  84. tpu_inference/layers/jax/misc.py +16 -0
  85. tpu_inference/layers/jax/moe/__init__.py +0 -0
  86. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  87. tpu_inference/layers/jax/moe/moe.py +209 -0
  88. tpu_inference/layers/jax/rope.py +172 -0
  89. tpu_inference/layers/jax/rope_interface.py +214 -0
  90. tpu_inference/layers/jax/sample/__init__.py +0 -0
  91. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  92. tpu_inference/layers/jax/sample/sampling.py +95 -0
  93. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  94. tpu_inference/layers/jax/sharding.py +406 -0
  95. tpu_inference/layers/jax/transformer_block.py +76 -0
  96. tpu_inference/layers/vllm/__init__.py +0 -0
  97. tpu_inference/layers/vllm/attention.py +184 -0
  98. tpu_inference/layers/vllm/fused_moe.py +399 -0
  99. tpu_inference/layers/vllm/linear_common.py +186 -0
  100. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  101. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  102. tpu_inference/layers/vllm/quantization/common.py +105 -0
  103. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  104. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  105. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  106. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  108. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  109. tpu_inference/layers/vllm/sharding.py +151 -0
  110. tpu_inference/logger.py +10 -0
  111. tpu_inference/lora/__init__.py +0 -0
  112. tpu_inference/lora/torch_lora_ops.py +103 -0
  113. tpu_inference/lora/torch_punica_tpu.py +308 -0
  114. tpu_inference/mock/__init__.py +0 -0
  115. tpu_inference/mock/vllm_config_utils.py +28 -0
  116. tpu_inference/mock/vllm_envs.py +1233 -0
  117. tpu_inference/mock/vllm_logger.py +212 -0
  118. tpu_inference/mock/vllm_logging_utils.py +15 -0
  119. tpu_inference/models/__init__.py +0 -0
  120. tpu_inference/models/common/__init__.py +0 -0
  121. tpu_inference/models/common/model_loader.py +433 -0
  122. tpu_inference/models/jax/__init__.py +0 -0
  123. tpu_inference/models/jax/deepseek_v3.py +868 -0
  124. tpu_inference/models/jax/llama3.py +366 -0
  125. tpu_inference/models/jax/llama4.py +473 -0
  126. tpu_inference/models/jax/llama_eagle3.py +333 -0
  127. tpu_inference/models/jax/phi3.py +376 -0
  128. tpu_inference/models/jax/qwen2.py +375 -0
  129. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  130. tpu_inference/models/jax/qwen3.py +302 -0
  131. tpu_inference/models/jax/utils/__init__.py +0 -0
  132. tpu_inference/models/jax/utils/file_utils.py +96 -0
  133. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  134. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  136. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  137. tpu_inference/models/vllm/__init__.py +0 -0
  138. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  139. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  140. tpu_inference/platforms/__init__.py +2 -0
  141. tpu_inference/platforms/tpu_jax.py +257 -0
  142. tpu_inference/runner/__init__.py +0 -0
  143. tpu_inference/runner/block_table_jax.py +122 -0
  144. tpu_inference/runner/compilation_manager.py +672 -0
  145. tpu_inference/runner/input_batch_jax.py +435 -0
  146. tpu_inference/runner/kv_cache.py +119 -0
  147. tpu_inference/runner/kv_cache_manager.py +460 -0
  148. tpu_inference/runner/lora_utils.py +92 -0
  149. tpu_inference/runner/multimodal_manager.py +208 -0
  150. tpu_inference/runner/persistent_batch_manager.py +244 -0
  151. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  152. tpu_inference/runner/structured_decoding_manager.py +89 -0
  153. tpu_inference/runner/tpu_jax_runner.py +771 -0
  154. tpu_inference/runner/utils.py +426 -0
  155. tpu_inference/spec_decode/__init__.py +0 -0
  156. tpu_inference/spec_decode/jax/__init__.py +0 -0
  157. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  158. tpu_inference/tpu_info.py +77 -0
  159. tpu_inference/utils.py +294 -0
  160. tpu_inference/worker/__init__.py +0 -0
  161. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  162. tpu_inference/worker/base.py +100 -0
  163. tpu_inference/worker/tpu_worker_jax.py +321 -0
  164. tpu_inference-0.11.1.dist-info/METADATA +101 -0
  165. tpu_inference-0.11.1.dist-info/RECORD +168 -0
  166. tpu_inference-0.11.1.dist-info/WHEEL +5 -0
  167. tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
  168. tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
@@ -0,0 +1,699 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ """
4
+ Proxy server routes the request to P with max_output_tokens=1
5
+
6
+ P workflow:
7
+ P receives the request
8
+
9
+ P scheduler checks if the prefill is full done in `request_finished()`
10
+ If done:
11
+ P puts the request-id in `scheduler_output.finished_req_ids`
12
+ and puts the request in `scheduler_output.kv_connector_metadata.reqs_to_send`
13
+ P responds the proxy server with `finished_req_ids` and the `kv_transfer_params`
14
+ P worker gets `reqs_to_send` and runs async `_prepare_kv_and_wait()`
15
+ Else:
16
+ P schedules the prefill with multiple turns due to chunked-prefill.
17
+
18
+ P worker checks if the request has been pulled by D
19
+ If done:
20
+ P worker puts the request-id in `done_sending()`
21
+ P scheduler frees blocks for the request in done sending.
22
+ Else:
23
+ P holds the blocks for the request until it's pulled by D
24
+
25
+ (
26
+ One scheduler step can finish:
27
+ scheduler RUNNING -> connector reqs_to_send -> worker prefill -> output
28
+ The waiting buffer will get freed after notified by D or expired.
29
+ )
30
+
31
+ Proxy server receives the response from P and forwards it to D
32
+
33
+ D workflow:
34
+ D receives the request
35
+
36
+ D scheduler calculates the num of tokens needing to pull from P in `get_num_new_matched_tokens()`
37
+ D checks if need to pull from P
38
+ If true:
39
+ D puts the request in `scheduler_output.kv_connector_metadata.reqs_to_load`
40
+ D worker gets `reqs_to_load` and runs `_pull_and_write_kv()` in separate threads (to be async)
41
+ D worker checks if the async loading is done:
42
+ If done:
43
+ D worker puts the request-id in `done_recving`.
44
+ D scheduler then knows the request can be scheduled for decoding now. The model decode
45
+ will happen in the next scheduler step.
46
+ Else:
47
+ D worker handles other requests first.
48
+ Else (too short prompt, full local prefix-cache):
49
+ D still needs to put the request in `reqs_to_load` but with None metadata, because D needs to
50
+ notify P the prefilled KV cache is no longer needed and can be freed in P.
51
+
52
+ (
53
+ Two scheduler steps can finish:
54
+ scheduler WAITING_FOR_REMOTE_KVS -> connector reqs_to_load -> worker wait for pulling
55
+ worker pulling done, notify P to free blocks
56
+ scheduler RUNNING -> connector reqs_to_load=None -> worker decode -> output
57
+ The waiting buffer will get freed after notified by D or expired.
58
+ )
59
+ """
60
+
61
+ import copy
62
+ import functools
63
+ import os
64
+ import threading
65
+ import time
66
+ from concurrent.futures import Future, ThreadPoolExecutor
67
+ from dataclasses import dataclass, field
68
+ from typing import TYPE_CHECKING, Any, Optional
69
+ from uuid import uuid4
70
+
71
+ import jax
72
+ import jax.numpy as jnp
73
+ import numpy as np
74
+ import zmq
75
+ from jax.experimental.transfer import start_transfer_server
76
+ from jax.sharding import Mesh
77
+ from vllm.config import VllmConfig
78
+ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
79
+ KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
80
+ from vllm.utils import make_zmq_path, make_zmq_socket, round_down
81
+ from vllm.v1.core.sched.output import SchedulerOutput
82
+ from vllm.v1.request import RequestStatus
83
+
84
+ if TYPE_CHECKING:
85
+ from vllm.v1.core.kv_cache_manager import KVCacheBlocks
86
+ from vllm.v1.request import Request
87
+
88
+ from tpu_inference.distributed.utils import (get_host_ip, get_kv_ips,
89
+ get_kv_ports,
90
+ get_kv_transfer_port, get_node_id,
91
+ get_side_channel_port)
92
+ from tpu_inference.logger import init_logger
93
+ from tpu_inference.runner.tpu_jax_runner import TPUModelRunner
94
+ from tpu_inference.utils import device_array
95
+
96
# Alias used for request identifiers; they are plain strings.
ReqId = str

# Feature requests:
# 1. support async pulling natively
# 2. partial pulling (like RDMA)
# 3. non-blocking jax array read/write

# KV cache that is waiting to be pulled will be cleared after
# this time (in seconds) if no pulling occurred on it.
P2P_WAIT_PULL_TIMEOUT = 120

# Module-level logger, named after this module.
logger = init_logger(__name__)
108
+
109
+
110
@dataclass
class SendMeta:
    """Per-request metadata the producer (P) scheduler hands to the P worker."""
    # Transfer identifier; the worker passes it to `await_pull`.
    uuid: int
    # Local KV-cache block ids whose contents are offered to the consumer.
    local_block_ids: list[int]
    # `time.perf_counter()`-based deadline (creation time + P2P_WAIT_PULL_TIMEOUT).
    expiration_time: float
115
+
116
+
117
@dataclass
class LoadMeta:
    """Per-request metadata the consumer (D) scheduler hands to the D worker."""
    # Transfer identifier taken from the request's kv_transfer_params.
    uuid: int
    # Local blocks to receive the pulled KV; None when nothing is pulled and
    # the entry only triggers the "pull done" notification to P.
    local_block_ids: list[int] | None
    # Blocks on P to pull; None marks a notification-only entry.
    remote_block_ids: list[int] | None
    # A single IP for single-host, or a list of IPs for multi-host.
    remote_host: str | list[str]
    # A single port for single-host, or a list of ports for multi-host.
    remote_port: int | list[int]
124
+
125
+
126
@dataclass
class _kv_transfer_params:
    """
    P prepares this in request_finished() and responds to proxy server.
    D receives this from proxy server and uses this to create LoadMeta.
    """
    # Transfer identifier generated on P for this request.
    uuid: int
    # Block ids on P that hold the prefilled KV cache.
    remote_block_ids: list[int]
    # A single IP for single-host, or a list of IPs for multi-host.
    remote_host: str | list[str]
    # A single port for single-host, or a list of ports for multi-host.
    remote_port: int | list[int]
138
+
139
+
140
# The metadata used for communicating between scheduler and worker connectors.
@dataclass
class TPUConnectorMetadata(KVConnectorMetadata):
    """Payload built by the scheduler connector and consumed by the worker connector."""
    # Filled on P: requests whose KV cache must be prepared for pulling.
    reqs_to_send: dict[ReqId, SendMeta] = field(default_factory=dict)
    # Filled on D: requests to pull from P, or to notify P about.
    reqs_to_load: dict[ReqId, LoadMeta] = field(default_factory=dict)
145
+
146
+
147
class TPUConnector(KVConnectorBase_V1):
    """KV connector facade for TPU disaggregated serving.

    Depending on the role it is created with, the instance owns either a
    `TPUConnectorScheduler` or a `TPUConnectorWorker` and forwards each
    connector hook to it. Hooks that belong to the other role assert.
    """

    def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
        assert vllm_config.kv_transfer_config is not None

        if role == KVConnectorRole.SCHEDULER:
            scheduler_side = TPUConnectorScheduler(vllm_config)
            self.connector_scheduler = scheduler_side
            self.connector_worker = None
        elif role == KVConnectorRole.WORKER:
            self.connector_scheduler = None
            worker_side = TPUConnectorWorker(vllm_config)
            self.connector_worker = worker_side

    ############################################################
    # Scheduler Side Methods
    ############################################################
    def get_num_new_matched_tokens(
            self, request: "Request",
            num_computed_tokens: int) -> tuple[int, bool]:
        """Forward to the scheduler-side connector."""
        scheduler_side = self.connector_scheduler
        assert scheduler_side is not None
        return scheduler_side.get_num_new_matched_tokens(
            request, num_computed_tokens)

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """Forward to the scheduler-side connector."""
        scheduler_side = self.connector_scheduler
        assert scheduler_side is not None
        return scheduler_side.update_state_after_alloc(
            request, blocks, num_external_tokens)

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> TPUConnectorMetadata:
        """Forward to the scheduler-side connector; `scheduler_output` is not used."""
        scheduler_side = self.connector_scheduler
        assert scheduler_side is not None
        return scheduler_side.build_connector_meta()

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """Forward to the scheduler-side connector."""
        scheduler_side = self.connector_scheduler
        assert scheduler_side is not None
        return scheduler_side.request_finished(request, block_ids)

    ############################################################
    # Worker Side Methods
    ############################################################
    def register_kv_caches(self, kv_caches: list[jax.Array]):
        """
        Intentionally a no-op: the connector is given the runner through
        `register_runner` and reads runner.kv_caches directly, because the
        ref of runner.kv_caches would be reassigned during model forward.
        """
        pass

    def register_runner(self, runner: TPUModelRunner) -> None:
        """Give the worker-side connector access to the model runner."""
        worker_side = self.connector_worker
        assert worker_side is not None
        worker_side.register_runner(runner)

    def start_load_kv(self, _, **kwargs) -> None:
        """Process the bound connector metadata (both sends and loads)."""
        worker_side = self.connector_worker
        assert worker_side is not None
        bound_meta = self._connector_metadata
        assert isinstance(bound_meta, TPUConnectorMetadata)
        worker_side.process_send_load(bound_meta)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """TPU connector doesn't support layer wise load."""
        pass

    def save_kv_layer(self, **kwargs) -> None:
        """TPU connector doesn't support layer wise save."""
        pass

    def wait_for_save(self):
        """
        Not useful for TPU, because by the design of vLLM KVConnectorModelRunnerMixin,
        this function is only called when scheduler_output.total_num_scheduled_tokens is not 0.
        But the reqs_to_send is only available after the req finished prefilling where the
        total_num_scheduled_tokens could be 0 if no other running reqs.
        So we run saving logic in `start_load_kv -> process_send_load` instead.
        """
        pass

    def get_finished(self,
                     finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        """Forward to the worker-side connector; `finished_req_ids` is not used."""
        worker_side = self.connector_worker
        assert worker_side is not None
        return worker_side.get_finished()
234
+
235
+
236
class TPUConnectorScheduler():
    """Scheduler-side half of the TPU KV connector.

    On a producer (P) it records finished prefills in `reqs_to_send`; on a
    consumer (D) it records requests to pull or to acknowledge in
    `reqs_to_load`. Both dicts are drained by `build_connector_meta()`.
    """

    # Keys of `request.kv_transfer_params` that D needs in order to reach P.
    _REMOTE_KEYS = ("uuid", "remote_host", "remote_port")

    def __init__(self, vllm_config: "VllmConfig"):
        self.vllm_config = vllm_config
        self.config = vllm_config.kv_transfer_config
        self.is_producer = self.config.is_kv_producer

        self.block_size = vllm_config.cache_config.block_size

        # This is updated in self.request_finished() for P,
        # each request that finished prefilling will be added to it.
        self.reqs_to_send: dict[ReqId, SendMeta] = {}

        # This is updated in self.update_state_after_alloc() for D,
        # each request that needs to pull KV cache from remote (or to notify
        # the remote that its KV cache can be freed) will be added to it.
        self.reqs_to_load: dict[ReqId, LoadMeta] = {}

        self.kv_ip = get_kv_ips()
        self.kv_port = get_kv_ports()
        logger.info(
            f"Scheduler --> kv_ip={self.kv_ip} | kv_port={self.kv_port}")

    @classmethod
    def _remote_params(cls,
                       request: "Request") -> Optional[dict[str, Any]]:
        """Return the request's kv_transfer_params if they identify a remote P.

        Returns None when the params are missing, empty (P sends `{}` when
        the prompt is shorter than one block), or lack any addressing key.
        """
        params = request.kv_transfer_params
        if not params:
            return None
        if any(key not in params for key in cls._REMOTE_KEYS):
            return None
        return params

    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int, bool]:
        """
        D workers use this to get the number of new tokens
        that can be loaded from remote P workers.
        No-op for P workers.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            A tuple with the following elements:
                - The number of tokens that will be loaded from the
                  external KV cache.
                - Whether the load is asynchronous. True whenever there is
                  something to load, because the blocking pull is executed
                  in a separate worker thread.
        """
        if self.is_producer:
            return 0, False

        params = self._remote_params(request)
        if params is None or params.get("remote_block_ids") is None:
            # Nothing was prefilled remotely for this request, so there is
            # nothing to pull; D computes the prompt locally.
            return 0, False

        assert num_computed_tokens % self.block_size == 0
        # This rounding logic must be consistent with calculating
        # remote_block_ids in P's request_finished()
        rounded_num_prompt_tokens = round_down(len(request.prompt_token_ids),
                                               self.block_size)
        count = max(rounded_num_prompt_tokens - num_computed_tokens, 0)
        # NOTE(xiang): Although the JAX P2P pulling is a blocking op, we will run it in a
        # separate thread to make it async, so we are safe to return True here.
        if count > 0:
            return count, True
        return 0, False

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """
        Update states after block allocation.
        No-op for P workers, and for requests without remote KV params.

        Args:
            request (Request): the request object.
            blocks (KVCacheBlocks): the blocks allocated for the request.
            num_external_tokens (int): the number of tokens that will be
                loaded from the external KV cache.
        """
        if self.is_producer:
            return

        params = self._remote_params(request)
        if params is None:
            # No remote prefill is attached to this request: there is no KV
            # to pull and no P instance holding blocks that must be notified.
            return

        if num_external_tokens > 0:
            # We need to load KV-cache from remote (partial prefix cache hit).
            local_block_ids = blocks.get_block_ids()[0]

            # NOTE(xiang): D needs to pull the whole prefill blocks from the remote
            # regardless how much ratio the prefix cache hits.
            # The reason is JAX P2P doesn't work as RDMA, instead it works like:
            # P just prepares the whole prefilled data and waits for pulling, then D pulls the
            # whole data. Which means even with partial prefix cache hit on D, D cannot only
            # pull the remaining partial data from P.
            # Unless we implement a side channel to let P know the prefix cache hit info on D,
            # so P can prepare those non-hit KV only, with that we need to change to:
            # local_block_ids = blocks.get_unhashed_block_ids()

            self.reqs_to_load[request.request_id] = LoadMeta(
                uuid=params["uuid"],
                local_block_ids=local_block_ids,
                remote_block_ids=params["remote_block_ids"],
                remote_host=params["remote_host"],
                remote_port=params["remote_port"],
            )
        else:
            # This branch means two cases:
            # 1. We don't need to load KV-cache from remote because of full local cache.
            # 2. The async pulling is done.
            # In both cases we need to send notification to let P free memory.
            self.reqs_to_load[request.request_id] = LoadMeta(
                uuid=params["uuid"],
                local_block_ids=None,
                remote_block_ids=None,
                remote_host=params["remote_host"],
                remote_port=params["remote_port"],
            )
        logger.info(f"Scheduler --> reqs_to_load={self.reqs_to_load}")

    def build_connector_meta(self) -> "TPUConnectorMetadata":
        """
        Build the scheduler metadata and pass to the downstream worker.

        This function should NOT modify fields in the scheduler_output.
        Also, calling this function will reset the state of the connector:
        the pending dict for this role is moved into the returned metadata
        and replaced by an empty one.
        """
        meta = TPUConnectorMetadata()

        if self.is_producer:
            meta.reqs_to_send = self.reqs_to_send
            self.reqs_to_send = {}
        else:
            meta.reqs_to_load = self.reqs_to_load
            self.reqs_to_load = {}

        return meta

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """
        Called when a request has finished, before its blocks are freed.
        No-op for D workers.

        Args:
            request (Request): the request object.
            block_ids: The block IDs allocated for this request and need to be freed.
        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """
        if not self.is_producer:
            return False, None

        # Mark the request finished only if the prefill is done and generates 1 output token.
        # The request's max_tokens has been reset to 1, so it must be finished by length capped.
        if request.status != RequestStatus.FINISHED_LENGTH_CAPPED:
            return False, None

        # NOTE(xiang): Get computed blocks rounded by block_size.
        # This indication means for the last partially filled block, we won't bother transferring
        # KV-cache, will just let D run prefill locally.
        all_full = request.num_computed_tokens % self.block_size == 0
        computed_block_ids = block_ids if all_full else block_ids[:-1]

        # If prompt < block_size, no transfer so free blocks immediately.
        delay_free_blocks = len(computed_block_ids) > 0

        if delay_free_blocks:
            uuid = get_uuid()
            expiration_time = time.perf_counter() + P2P_WAIT_PULL_TIMEOUT
            self.reqs_to_send[request.request_id] = SendMeta(
                uuid=uuid,
                local_block_ids=computed_block_ids,
                expiration_time=expiration_time)
            kv_transfer_params = dict(uuid=uuid,
                                      remote_block_ids=computed_block_ids,
                                      remote_host=self.kv_ip,
                                      remote_port=self.kv_port)
            logger.info(f"Scheduler ----> reqs_to_send={self.reqs_to_send} | "
                        f"kv_transfer_params={kv_transfer_params}")
        else:
            # Empty params tell D that there is nothing to pull or acknowledge.
            kv_transfer_params = {}

        return delay_free_blocks, kv_transfer_params
419
+
420
+
421
+ class TPUConnectorWorker:
422
+
423
def __init__(self, vllm_config: VllmConfig):
    """Set up worker-side state, the KV transfer server and the side channel.

    A producer starts a daemon thread that listens for "pull done"
    notifications and blocks until that thread signals it is ready. A
    consumer creates a thread pool for pulls plus connection and
    notification-socket caches.

    Args:
        vllm_config: engine config; `kv_transfer_config.is_kv_producer`
            selects the producer or consumer role.
    """
    self.vllm_config = vllm_config
    self.config = vllm_config.kv_transfer_config
    self.is_producer = self.config.is_kv_producer

    # Both are filled in later by register_runner().
    self.runner: TPUModelRunner = None
    self.mesh: Mesh = None
    self.multi_host = os.getenv("TPU_MULTIHOST_BACKEND",
                                "").lower() == "ray"
    # NOTE(xiang): This can not be the worker rank set in RayDistributedExecutor.
    # The worker rank is assigned with vLLM's sorting logic, which does not work
    # for TPU host topology.
    self.node_id = get_node_id()

    # req_id: [kv, expiration_time] (a mutable 2-item list; the notify
    # listener overwrites item 1 with -1 once the pull is acknowledged)
    self.reqs_wait_pull: dict[ReqId, list[list[jax.Array], float]] = {}
    # req_id: thread_future
    self.reqs_pulling: dict[ReqId, Future] = {}

    self.host_ip = get_host_ip()
    self.kv_transfer_port = get_kv_transfer_port()
    self.side_channel_port = get_side_channel_port()

    # Must be None before the call below, which only starts a server when
    # none exists yet.
    self.kv_transfer_server = None
    self._maybe_start_p2p_server()
    self.zmq_cxt = zmq.Context()
    if self.is_producer:
        # The listener sets the event after its socket is created, so the
        # constructor does not return before the side channel is up.
        ready_event = threading.Event()
        self.pull_notify_listener_t = threading.Thread(
            target=self._pull_notify_listener,
            args=(ready_event, ),
            daemon=True,
        )
        self.pull_notify_listener_t.start()
        ready_event.wait()
    else:
        # Pulls are blocking, so they run on this pool.
        self.pull_executor = ThreadPoolExecutor(max_workers=64)
        # "host:port" string -> transfer connection
        self.pull_conns: dict[str, Any] = {}
        # zmq path -> notification socket
        self.notif_sockets: dict[str, zmq.Socket] = {}

    logger.info(f"Worker {self.node_id} --> init | "
                f"ip={self.host_ip} | "
                f"kv_transfer_port={self.kv_transfer_port} | "
                f"side_channel_port={self.side_channel_port}")
467
+
468
def __del__(self):
    """Best-effort release of the listener thread, thread pool and ZMQ context.

    `__del__` also runs for instances whose `__init__` raised part-way, so
    every attribute is looked up defensively instead of assuming it exists.
    """
    if getattr(self, "is_producer", False):
        listener = getattr(self, "pull_notify_listener_t", None)
        if listener is not None:
            # Non-blocking join: the listener is a daemon thread and is not
            # expected to finish on its own.
            listener.join(timeout=0)
    else:
        executor = getattr(self, "pull_executor", None)
        if executor is not None:
            executor.shutdown(wait=False)

    zmq_context = getattr(self, "zmq_cxt", None)
    if zmq_context is not None:
        zmq_context.destroy(linger=0)
474
+
475
def register_runner(self, runner: TPUModelRunner):
    """Attach the model runner and cache the layout of its KV caches.

    The first layer's shape, dtype and sharding are recorded and later
    used to describe the arrays to pull.
    """
    self.runner = runner
    self.mesh = runner.mesh

    # Every layer is described by the spec of the first one.
    layer_caches = runner.kv_caches
    first_layer = layer_caches[0]
    self.num_layers = len(layer_caches)
    self.shape = list(first_layer.shape)
    self.dtype = first_layer.dtype
    self.sharding = first_layer.sharding
486
+
487
def _maybe_start_p2p_server(self):
    """Start the JAX transfer server once; later calls are no-ops."""
    if self.kv_transfer_server is None:
        host = self.host_ip
        self.kv_transfer_server = start_transfer_server(
            jax.local_devices()[0].client,
            # Address the server itself listens on.
            f"{host}:{self.kv_transfer_port}",
            # Transport address; port 0 is passed as-is to the server.
            [f'{host}:0'],
            max_num_parallel_copies=8,
            transfer_size=256 * 1024 * 1024,
            use_raw_buffers=False,
        )
        logger.info(
            f"Worker {self.node_id} --> kv transfer | addr={self.kv_transfer_server.address()}"
        )
503
+
504
def _pull_notify_listener(self, ready_event: threading.Event):
    """Thread target on P: receive "pull done" notifications from D, forever.

    Each message carries a request id. The matching `reqs_wait_pull` entry
    has its expiration time set to -1, which marks it as done.

    An unknown request id is logged instead of raised: an exception here
    would end this daemon thread and every later notification would be
    lost.

    Args:
        ready_event: set once the listening socket has been created, so the
            constructor can stop waiting.
    """
    sock_path = make_zmq_path("tcp", "*", self.side_channel_port)
    sock = make_zmq_socket(ctx=self.zmq_cxt,
                           path=sock_path,
                           socket_type=zmq.ROUTER,
                           bind=True)
    ready_event.set()
    logger.info(
        f"Worker {self.node_id} --> zmq listener | sock_path={sock_path}")

    while True:
        # ROUTER sockets prepend the sender identity frame.
        client_id, req_id_bytes = sock.recv_multipart()
        req_id = req_id_bytes.decode('utf-8')
        logger.info(
            f"Worker {self.node_id} --> zmq recieve | req_id={req_id}")
        try:
            # Set the expiration time of this request to -1, mark to be done.
            # Single lookup (no separate membership test) because the dict is
            # also mutated outside this thread.
            self.reqs_wait_pull[req_id][1] = -1
        except KeyError:
            logger.error(
                f"Disagg producer recives a non-exist pulling finished notification request {req_id}"
            )
        # Yield to other threads before blocking on the next message.
        time.sleep(0)
        # The response is not really needed.
        # sock.send_multipart([client_id, b"", b"ACK"])
529
+
530
def process_send_load(self, metadata: TPUConnectorMetadata):
    """
    Handle the scheduler's metadata for this step.

    This is called in runner before calling model forward,
    whether or not scheduler_output.total_num_scheduled_tokens is empty.

    On P, every entry of `reqs_to_send` is staged for pulling. On D, an
    entry of `reqs_to_load` either starts an asynchronous pull or, when it
    carries no remote block ids, sends the "pull done" notification to P.
    """
    to_send = metadata.reqs_to_send
    if to_send:
        assert self.is_producer
        logger.info(f"Worker {self.node_id} --> reqs_to_send={to_send}")
        for req_id, send_meta in to_send.items():
            self._prepare_kv_and_wait(req_id, send_meta)

    to_load = metadata.reqs_to_load
    if not to_load:
        return

    assert not self.is_producer
    logger.info(f"Worker {self.node_id} --> reqs_to_load={to_load}")
    for req_id, load_meta in to_load.items():
        if load_meta.remote_block_ids is None:
            # The request has finished pulling the KV from remote, or it has full local
            # prefix cache, need to notify P to let it free blocks.
            notif_sock = self._maybe_build_notif_socket(load_meta)
            self._notify_pull_done(notif_sock, req_id)
            continue

        # The request requires pulling KV from P: build the connection and
        # pull the data asynchronously on the executor.
        conn = self._maybe_build_kv_connection(load_meta)
        self.reqs_pulling[req_id] = self.pull_executor.submit(
            self._pull_kv, conn, load_meta)
558
+
559
def _prepare_kv_and_wait(self, req_id: str, req_meta: SendMeta):
    """Select the request's KV blocks on P and offer them for pulling."""
    # TODO(xiang): pad block_ids to avoid recompilation
    block_indices = device_array(self.mesh,
                                 np.array(req_meta.local_block_ids))
    selected_kv = select_from_kv_caches(self.runner.kv_caches, block_indices)
    # NOTE(xiang): We need to manually store the kv because:
    # Although we can set use_raw_buffers=True to let kv be safely destroyed after
    # calling await_pull, it could be a stranded buffer if D never pulls it.
    # So we have to set use_raw_buffers=False and store the kv, then the kv buffer
    # will be safely destroyed by either D notifying or expiration.
    self.reqs_wait_pull[req_id] = [selected_kv, req_meta.expiration_time]
    self.kv_transfer_server.await_pull(req_meta.uuid, selected_kv)
571
+
572
def _maybe_build_kv_connection(self, req_meta: LoadMeta) -> Any:
    """Return a transfer connection to the request's remote, creating it once.

    Connections are cached in `self.pull_conns`, keyed by "host:port".
    """
    # NOTE(review): remote_host/remote_port may be lists for multi-host;
    # they are formatted into the key as-is here -- confirm with callers.
    remote_addr = f"{req_meta.remote_host}:{req_meta.remote_port}"
    if remote_addr in self.pull_conns:
        return self.pull_conns[remote_addr]

    new_conn = self.kv_transfer_server.connect(remote_addr)
    self.pull_conns[remote_addr] = new_conn
    logger.info(
        f"Worker {self.node_id} --> kv transfer | connect={remote_addr}"
    )
    return new_conn
583
+
584
def _pull_kv(self, conn: Any, req_meta: LoadMeta):
    """Pull the KV data of one request over an established connection.

    Args:
        conn: Connection object exposing ``pull(uuid, kv_spec)``.
        req_meta: Load-side metadata; ``local_block_ids``,
            ``remote_block_ids`` and ``uuid`` are read.

    Returns:
        A ``(kv, indices)`` pair: the pulled KV data and the device array of
        local block ids built from ``req_meta.local_block_ids``.

    Raises:
        AssertionError: If the local and remote block lists differ in length.
    """
    # Locally allocated blocks that did not hit the prefix cache.
    dst_block_ids = req_meta.local_block_ids
    # Blocks computed remotely that have to be pulled from P.
    src_block_ids = req_meta.remote_block_ids
    # Partial prefix-cache hits are not handled yet, so both sides must
    # describe the same number of blocks.
    assert len(dst_block_ids) == len(src_block_ids)

    spec = self._get_kv_spec(len(src_block_ids))
    # TODO(xiang): pad block_ids to avoid recompilation
    dst_indices = device_array(self.mesh, np.array(dst_block_ids))
    pulled_kv = conn.pull(req_meta.uuid, spec)
    logger.info(
        f"Worker {self.node_id} --> kv transfer | pull uuid={req_meta.uuid}"
    )
    return pulled_kv, dst_indices
601
+
602
+ def _get_kv_spec(self, num_blocks: int) -> list[jax.ShapeDtypeStruct]:
603
+ assert num_blocks <= self.shape[0]
604
+ shape = copy.copy(self.shape)
605
+ shape[0] = num_blocks
606
+ return [
607
+ jax.ShapeDtypeStruct(shape, self.dtype, sharding=self.sharding)
608
+ ] * self.num_layers
609
+
610
def _maybe_build_notif_socket(self, req_meta: LoadMeta) -> zmq.Socket:
    """Return the notification socket for the request's remote host.

    Sockets are cached in ``self.notif_sockets`` keyed by their ZMQ path
    (``tcp`` + ``req_meta.remote_host`` + ``self.side_channel_port``). A new
    DEALER socket is created only on a cache miss and is stored in the cache
    so later notifications to the same host reuse it instead of opening a
    fresh socket every time.

    Args:
        req_meta: Load-side metadata providing ``remote_host``.

    Returns:
        The cached or newly created ``zmq.Socket``.
    """
    sock_path = make_zmq_path("tcp", req_meta.remote_host,
                              self.side_channel_port)
    if sock_path in self.notif_sockets:
        return self.notif_sockets[sock_path]

    sock = make_zmq_socket(ctx=self.zmq_cxt,
                           path=sock_path,
                           socket_type=zmq.DEALER,
                           bind=False)
    # Remember the socket; without this the lookup above can never hit and
    # every notification would create (and never reuse) another socket.
    self.notif_sockets[sock_path] = sock
    logger.info(
        f"Worker {self.node_id} --> zmq notify | sock_path={sock_path}"
    )
    return sock
624
+
625
def _notify_pull_done(self, sock: zmq.Socket, req_id: str):
    """Send ``req_id`` over ``sock`` to signal that its pull has finished.

    Args:
        sock: Socket on which the notification is sent.
        req_id: Request id sent as the message payload.
    """
    logger.info(f"Worker {self.node_id} --> zmq notify | req_id={req_id}")
    # Fire-and-forget: the response is not really needed, so no
    # recv_string() is issued after sending.
    sock.send_string(req_id)
630
+
631
def get_finished(self) -> tuple[set[str], set[str]]:
    """Collect the requests whose sending or receiving has completed.

    A request counts as done receiving once its future in
    ``self.reqs_pulling`` has completed; its pulled KV is then scattered into
    ``self.runner.kv_caches``. A request counts as done sending once its
    expiration time in ``self.reqs_wait_pull`` has passed. Finished entries
    are removed from the respective dict.

    Returns:
        ``(done_sending, done_recving)`` sets of request ids.
    """
    done_sending: set[str] = set()
    done_recving: set[str] = set()
    # Nothing in flight in either direction: return the empty sets.
    if not (self.reqs_wait_pull or self.reqs_pulling):
        return done_sending, done_recving

    # Mark a req as done receiving after its pulling thread returns.
    # This req can then be scheduled for decoding in the next scheduler step.
    for req_id in tuple(self.reqs_pulling.keys()):
        pull_future = self.reqs_pulling[req_id]
        if not pull_future.done():
            continue
        # NOTE(xiang): the scatter runs in the main thread to avoid a data
        # race. The race is not on the kv_caches buffer itself but on the
        # runner.kv_caches reference.
        pulled_kv, block_indices = pull_future.result()
        self.runner.kv_caches = scatter_kv_slices(self.runner.kv_caches,
                                                  pulled_kv, block_indices)
        del self.reqs_pulling[req_id]
        done_recving.add(req_id)

    # Mark a req as done sending when it's expired.
    # This req can then have its blocks released in the current scheduler step.
    current_time = time.perf_counter()
    for req_id in tuple(self.reqs_wait_pull):
        _, expiration = self.reqs_wait_pull[req_id]
        if current_time > expiration:
            del self.reqs_wait_pull[req_id]
            done_sending.add(req_id)

    if done_sending:
        logger.info(
            f"Worker {self.node_id} --> done_sending={done_sending}")
    if done_recving:
        logger.info(
            f"Worker {self.node_id} --> done_recving={done_recving}")
    return done_sending, done_recving
665
+
666
+
667
def get_uuid() -> int:
    """Return a random integer made of the upper 64 bits of a fresh UUID4."""
    # Must be 64-bit int, otherwise vllm output encoder would raise error.
    return uuid4().int >> 64
672
+
673
+
674
@jax.jit
def select_from_kv_caches(kv_caches: list[jax.Array],
                          indices: list[jax.Array]) -> list[jax.Array]:
    """Gather the entries at ``indices`` from every cache in ``kv_caches``.

    Args:
        kv_caches: Caches to read from, one array per list entry.
        indices: Index used along the leading dimension of each cache.

    Returns:
        A list with one gathered array per input cache, in the same order.
    """
    gathered = []
    for layer_cache in kv_caches:
        gathered.append(layer_cache.at[indices].get())
    return gathered
679
+
680
+
681
@functools.partial(
    jax.jit,
    donate_argnames=("kv_caches", ),
)
def scatter_kv_slices(kv_caches: list[jax.Array], kv_slices: list[jax.Array],
                      indices: jax.Array) -> list[jax.Array]:
    """Write ``kv_slices`` into ``kv_caches`` at the rows given by ``indices``.

    ``indices`` may be longer than the slices' leading dimension (padded
    indices). In that case each slice is zero-padded along its leading
    dimension to the length of ``indices`` before the write, so the rows
    addressed by the extra indices receive zeros. The padding is derived from
    each slice's rank, so caches of any rank are supported.

    Args:
        kv_caches: Caches to update, one per list entry; donated to the
            jitted computation.
        kv_slices: Data to write, paired with ``kv_caches`` by position.
        indices: Array of row indices along the caches' leading dimension.

    Returns:
        The list of updated caches, in the same order as ``kv_caches``.

    Raises:
        AssertionError: If the slices have more rows than there are indices.
    """
    num_indices = indices.shape[0]
    num_slices = kv_slices[0].shape[0]
    # indices might be padded
    assert num_slices <= num_indices
    pad_rows = num_indices - num_slices

    new_kv_caches = []
    for cache, kv_slice in zip(kv_caches, kv_slices):
        if pad_rows:
            # Pad only the leading dimension; leave all others untouched.
            pad_width = [(0, pad_rows)] + [(0, 0)] * (kv_slice.ndim - 1)
            kv_slice = jnp.pad(kv_slice, pad_width)
        new_kv_caches.append(cache.at[indices].set(kv_slice))
    return new_kv_caches