tpu-inference 0.0.1rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (174)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +374 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +648 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +88 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +203 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +235 -0
  27. tpu_inference/__init__.py +53 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +49 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +727 -0
  37. tpu_inference/distributed/utils.py +60 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +160 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +382 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1566 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1501 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1603 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +396 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +469 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +110 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +331 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +368 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +310 -0
  120. tpu_inference/models/__init__.py +0 -0
  121. tpu_inference/models/common/__init__.py +0 -0
  122. tpu_inference/models/common/model_loader.py +478 -0
  123. tpu_inference/models/jax/__init__.py +0 -0
  124. tpu_inference/models/jax/deepseek_v3.py +868 -0
  125. tpu_inference/models/jax/gpt_oss.py +492 -0
  126. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  127. tpu_inference/models/jax/llama3.py +376 -0
  128. tpu_inference/models/jax/llama4.py +629 -0
  129. tpu_inference/models/jax/llama_eagle3.py +336 -0
  130. tpu_inference/models/jax/llama_guard_4.py +361 -0
  131. tpu_inference/models/jax/qwen2.py +376 -0
  132. tpu_inference/models/jax/qwen2_5_vl.py +1218 -0
  133. tpu_inference/models/jax/qwen3.py +303 -0
  134. tpu_inference/models/jax/utils/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/file_utils.py +96 -0
  136. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  137. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  138. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  139. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  140. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  141. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  142. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  143. tpu_inference/models/jax/utils/quantization/quantization_utils.py +650 -0
  144. tpu_inference/models/jax/utils/weight_utils.py +584 -0
  145. tpu_inference/models/vllm/__init__.py +0 -0
  146. tpu_inference/models/vllm/vllm_model_wrapper.py +293 -0
  147. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  148. tpu_inference/platforms/__init__.py +2 -0
  149. tpu_inference/platforms/tpu_platform.py +275 -0
  150. tpu_inference/runner/__init__.py +0 -0
  151. tpu_inference/runner/block_table.py +122 -0
  152. tpu_inference/runner/compilation_manager.py +865 -0
  153. tpu_inference/runner/input_batch.py +435 -0
  154. tpu_inference/runner/kv_cache.py +132 -0
  155. tpu_inference/runner/kv_cache_manager.py +478 -0
  156. tpu_inference/runner/lora_utils.py +92 -0
  157. tpu_inference/runner/multimodal_manager.py +217 -0
  158. tpu_inference/runner/persistent_batch_manager.py +282 -0
  159. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  160. tpu_inference/runner/structured_decoding_manager.py +87 -0
  161. tpu_inference/runner/tpu_runner.py +1744 -0
  162. tpu_inference/runner/utils.py +426 -0
  163. tpu_inference/spec_decode/__init__.py +0 -0
  164. tpu_inference/spec_decode/jax/__init__.py +0 -0
  165. tpu_inference/spec_decode/jax/eagle3.py +417 -0
  166. tpu_inference/tpu_info.py +78 -0
  167. tpu_inference/utils.py +340 -0
  168. tpu_inference/worker/__init__.py +0 -0
  169. tpu_inference/worker/tpu_worker.py +458 -0
  170. tpu_inference-0.0.1rc1.dist-info/METADATA +108 -0
  171. tpu_inference-0.0.1rc1.dist-info/RECORD +174 -0
  172. tpu_inference-0.0.1rc1.dist-info/WHEEL +5 -0
  173. tpu_inference-0.0.1rc1.dist-info/licenses/LICENSE +201 -0
  174. tpu_inference-0.0.1rc1.dist-info/top_level.txt +2 -0
tpu_inference/distributed/tpu_connector.py
@@ -0,0 +1,727 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+ """
+ Proxy server routes the request to P with max_output_tokens=1
+
+ P workflow:
+     P receives the request
+
+     P scheduler checks if the prefill is fully done in `request_finished()`
+     If done:
+         P puts the request-id in `scheduler_output.finished_req_ids`
+         and puts the request in `scheduler_output.kv_connector_metadata.reqs_to_send`
+         P responds to the proxy server with `finished_req_ids` and the `kv_transfer_params`
+         P worker gets `reqs_to_send` and runs async `_prepare_kv_and_wait()`
+     Else:
+         P schedules the prefill over multiple turns due to chunked prefill.
+
+     P worker checks if the request has been pulled by D
+     If done:
+         P worker puts the request-id in `done_sending()`
+         P scheduler frees blocks for the request in done sending.
+     Else:
+         P holds the blocks for the request until it's pulled by D
+
+     (
+     One scheduler step can finish:
+         scheduler RUNNING -> connector reqs_to_send -> worker prefill -> output
+     The waiting buffer will get freed after being notified by D or after expiring.
+     )
+
+ Proxy server receives the response from P and forwards it to D
+
+ D workflow:
+     D receives the request
+
+     D scheduler calculates the number of tokens that need to be pulled from P
+     in `get_num_new_matched_tokens()`
+     D checks if it needs to pull from P
+     If true:
+         D puts the request in `scheduler_output.kv_connector_metadata.reqs_to_load`
+         D worker gets `reqs_to_load` and runs `_pull_kv()` in separate threads (to be async)
+         D worker checks if the async loading is done:
+         If done:
+             D worker puts the request-id in `done_recving`.
+             D scheduler then knows the request can be scheduled for decoding now. The model decode
+             will happen in the next scheduler step.
+         Else:
+             D worker handles other requests first.
+     Else (too short prompt, or full local prefix-cache hit):
+         D still needs to put the request in `reqs_to_load`, but with None metadata, because D needs
+         to notify P that the prefilled KV cache is no longer needed and can be freed on P.
+
+     (
+     Two scheduler steps can finish:
+         scheduler WAITING_FOR_REMOTE_KVS -> connector reqs_to_load -> worker wait for pulling
+         worker pulling done, notify P to free blocks
+         scheduler RUNNING -> connector reqs_to_load=None -> worker decode -> output
+     The waiting buffer will get freed after being notified by D or after expiring.
+     )
+ """
+
+ import copy
+ import functools
+ import threading
+ import time
+ from concurrent.futures import Future, ThreadPoolExecutor
+ from dataclasses import dataclass, field
+ from typing import TYPE_CHECKING, Any, Optional
+ from uuid import uuid4
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import zmq
+ from jax.experimental.transfer import start_transfer_server
+ from jax.sharding import Mesh
+ from vllm.config import VllmConfig
+ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
+     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
+ from vllm.utils.math_utils import round_down
+ from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
+ from vllm.v1.core.sched.output import SchedulerOutput
+ from vllm.v1.request import RequestStatus
+
+ if TYPE_CHECKING:
+     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
+     from vllm.v1.request import Request
+
+ from tpu_inference import envs
+ from tpu_inference.distributed.utils import (get_host_ip, get_kv_ips,
+                                              get_kv_ports,
+                                              get_kv_transfer_port, get_node_id,
+                                              get_side_channel_port)
+ from tpu_inference.logger import init_logger
+ from tpu_inference.runner.tpu_runner import TPUModelRunner
+ from tpu_inference.utils import device_array
+
+ ReqId = str
+
+ # Feature requests:
+ # 1. support async pulling natively
+ # 2. partial pulling (like RDMA)
+ # 3. non-blocking jax array read/write
+
+ # A KV cache awaiting pull will be cleared after this
+ # time (in seconds) if no pull has occurred on it.
+ P2P_WAIT_PULL_TIMEOUT = 120
+
+ logger = init_logger(__name__)
+
+
+ @dataclass
+ class SendMeta:
+     uuid: int
+     local_block_ids: list[int]
+     expiration_time: float
+
+
+ @dataclass
+ class LoadMeta:
+     uuid: int
+     # These two are None when there is nothing to pull
+     # (see update_state_after_alloc()).
+     local_block_ids: list[int] | None
+     remote_block_ids: list[int] | None
+     remote_host: str | list[str]
+     remote_port: int | list[int]
+
+
+ @dataclass
+ class _kv_transfer_params:
+     """
+     P prepares this in request_finished() and responds to the proxy server.
+     D receives this from the proxy server and uses it to create LoadMeta.
+     """
+     uuid: int
+     remote_block_ids: list[int]
+     # A single IP for single-host, or a list of IPs for multi-host.
+     remote_host: str | list[str]
+     # A single port for single-host, or a list of ports for multi-host.
+     remote_port: int | list[int]
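+
+ # Illustrative payload (values made up) as it travels through the proxy server:
+ #   {"uuid": 7205759403792794,
+ #    "remote_block_ids": [0, 1, 2, 3],
+ #    "remote_host": "10.0.0.1",   # or a list of IPs for multi-host
+ #    "remote_port": 7000}         # or a list of ports for multi-host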
+
+
+ # The metadata used for communicating between scheduler and worker connectors.
+ @dataclass
+ class TPUConnectorMetadata(KVConnectorMetadata):
+     reqs_to_send: dict[ReqId, SendMeta] = field(default_factory=dict)
+     reqs_to_load: dict[ReqId, LoadMeta] = field(default_factory=dict)
+
+
+ class TPUConnector(KVConnectorBase_V1):
+
+     def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
+         assert vllm_config.kv_transfer_config is not None
+
+         if role == KVConnectorRole.SCHEDULER:
+             self.connector_scheduler = \
+                 TPUConnectorScheduler(vllm_config)
+             self.connector_worker = None
+         elif role == KVConnectorRole.WORKER:
+             self.connector_scheduler = None
+             self.connector_worker = TPUConnectorWorker(vllm_config)
+
+     ############################################################
+     # Scheduler Side Methods
+     ############################################################
+     def get_num_new_matched_tokens(
+             self, request: "Request",
+             num_computed_tokens: int) -> tuple[int, bool]:
+         assert self.connector_scheduler is not None
+         return self.connector_scheduler.get_num_new_matched_tokens(
+             request, num_computed_tokens)
+
+     def update_state_after_alloc(self, request: "Request",
+                                  blocks: "KVCacheBlocks",
+                                  num_external_tokens: int):
+         assert self.connector_scheduler is not None
+         return self.connector_scheduler.update_state_after_alloc(
+             request, blocks, num_external_tokens)
+
+     def build_connector_meta(
+         self,
+         scheduler_output: SchedulerOutput,
+     ) -> TPUConnectorMetadata:
+         assert self.connector_scheduler is not None
+         return self.connector_scheduler.build_connector_meta()
+
+     def request_finished(
+         self,
+         request: "Request",
+         block_ids: list[int],
+     ) -> tuple[bool, Optional[dict[str, Any]]]:
+         assert self.connector_scheduler is not None
+         return self.connector_scheduler.request_finished(request, block_ids)
+
+     def get_finished_count(self) -> int:
+         assert self.connector_scheduler is not None
+         return self.connector_scheduler.get_finished_count()
+
+     ############################################################
+     # Worker Side Methods
+     ############################################################
+     def register_kv_caches(self, kv_caches: list[jax.Array]):
+         """
+         We don't register kv_caches in the connector; we call `register_runner`
+         and use runner.kv_caches directly instead, because the ref of
+         runner.kv_caches gets reassigned during model forward.
+         """
+         pass
+
+     def register_runner(self, runner: TPUModelRunner) -> None:
+         assert self.connector_worker is not None
+         self.connector_worker.register_runner(runner)
+
+     def start_load_kv(self, _, **kwargs) -> None:
+         assert self.connector_worker is not None
+         assert isinstance(self._connector_metadata, TPUConnectorMetadata)
+         self.connector_worker.process_send_load(self._connector_metadata)
+
+     def wait_for_layer_load(self, layer_name: str) -> None:
+         """TPU connector doesn't support layer-wise load."""
+         pass
+
+     def save_kv_layer(self, **kwargs) -> None:
+         """TPU connector doesn't support layer-wise save."""
+         pass
+
+     def wait_for_save(self):
+         """
+         Not useful for TPU: by the design of vLLM's KVConnectorModelRunnerMixin,
+         this function is only called when scheduler_output.total_num_scheduled_tokens
+         is not 0. But reqs_to_send only becomes available after a request finishes
+         prefilling, at which point total_num_scheduled_tokens could be 0 if there
+         are no other running reqs.
+         So we run the saving logic in `start_load_kv -> process_send_load` instead.
+         """
+         pass
+
+     def get_finished(self,
+                      finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
+         assert self.connector_worker is not None
+         return self.connector_worker.get_finished()
+
+
+ class TPUConnectorScheduler():
+
+     def __init__(self, vllm_config: "VllmConfig"):
+         self.vllm_config = vllm_config
+         self.config = vllm_config.kv_transfer_config
+         self.is_producer = self.config.is_kv_producer
+
+         self.block_size = vllm_config.cache_config.block_size
+
+         # This is updated in self.request_finished() for P;
+         # each request that finished prefilling will be added to it.
+         self.reqs_to_send: dict[ReqId, SendMeta] = {}
+
+         # This is updated in self.update_state_after_alloc() for D;
+         # each request that needs to pull KV cache from remote will be added to it.
+         self.reqs_to_load: dict[ReqId, LoadMeta] = {}
+
+         self.kv_ip = get_kv_ips()
+         self.kv_port = get_kv_ports()
+         logger.info(
+             f"TPUConnectorScheduler --> kv_ip={self.kv_ip} | kv_port={self.kv_port}"
+         )
+
+     def get_num_new_matched_tokens(
+         self,
+         request: "Request",
+         num_computed_tokens: int,
+     ) -> tuple[int, bool]:
+         """
+         D workers use this to get the number of new tokens
+         that can be loaded from remote P workers.
+         No-op for P workers.
+
+         Args:
+             request (Request): the request object.
+             num_computed_tokens (int): the number of locally
+                 computed tokens for this request
+
+         Returns:
+             A tuple with the following elements:
+                 - The number of tokens that will be loaded from the
+                   external KV cache.
+                 - Whether the loading is async. True for the TPU connector:
+                   although the JAX P2P pull itself is blocking, it runs in a
+                   separate thread (see the NOTE below).
+         """
+         if self.is_producer or not request.kv_transfer_params:
+             return 0, False
+
+         assert num_computed_tokens % self.block_size == 0
+         # This rounding logic must be consistent with calculating
+         # remote_block_ids in P's request_finished()
+         rounded_num_prompt_tokens = round_down(len(request.prompt_token_ids),
+                                                self.block_size)
+         count = max(rounded_num_prompt_tokens - num_computed_tokens, 0)
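+         # Worked example (illustrative): with block_size=128 and a 300-token
+         # prompt, rounded_num_prompt_tokens = 256; if 128 tokens are already
+         # computed locally, count = 256 - 128 = 128 tokens to pull from P.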
+         # NOTE(xiang): Although the JAX P2P pull is a blocking op, we run it in a
+         # separate thread to make it async, so we are safe to return True here.
+         if count > 0:
+             return count, True
+         return 0, False
+
+     def update_state_after_alloc(self, request: "Request",
+                                  blocks: "KVCacheBlocks",
+                                  num_external_tokens: int):
+         """
+         Update states after block allocation.
+         No-op for P workers.
+
+         Args:
+             request (Request): the request object.
+             blocks (KVCacheBlocks): the blocks allocated for the request.
+             num_external_tokens (int): the number of tokens that will be
+                 loaded from the external KV cache.
+         """
+         if self.is_producer or not request.kv_transfer_params:
+             return
+
+         params = request.kv_transfer_params
+         if num_external_tokens > 0:
+             # We need to load KV-cache from remote (partial prefix cache hit).
+             local_block_ids = blocks.get_block_ids()[0]
+
+             # NOTE(xiang): D needs to pull all of the prefill blocks from the
+             # remote, regardless of how much of the prefix cache hits.
+             # The reason is that JAX P2P doesn't work like RDMA; instead it works
+             # like this: P prepares the whole prefilled data and waits for pulling,
+             # then D pulls the whole data. This means that even with a partial
+             # prefix-cache hit on D, D cannot pull only the remaining partial
+             # data from P.
+             # Unless we implement a side channel to let P know the prefix-cache
+             # hit info on D, so P can prepare only the non-hit KV; with that we
+             # would need to change to:
+             # local_block_ids = blocks.get_unhashed_block_ids()
+
+             self.reqs_to_load[request.request_id] = LoadMeta(
+                 uuid=params["uuid"],
+                 local_block_ids=local_block_ids,
+                 remote_block_ids=params["remote_block_ids"],
+                 remote_host=params["remote_host"],
+                 remote_port=params["remote_port"],
+             )
+         else:
+             # This branch covers two cases:
+             # 1. We don't need to load KV-cache from remote because of a full
+             #    local cache hit.
+             # 2. The async pulling is done.
+             # In both cases we need to send a notification to let P free memory.
+             self.reqs_to_load[request.request_id] = LoadMeta(
+                 uuid=params["uuid"],
+                 local_block_ids=None,
+                 remote_block_ids=None,
+                 remote_host=params["remote_host"],
+                 remote_port=params["remote_port"],
+             )
+         logger.info(
+             f"TPUConnector Scheduler update_state_after_alloc --> reqs_to_load={self.reqs_to_load}"
+         )
+
+     def build_connector_meta(self) -> TPUConnectorMetadata:
+         """
+         Build the scheduler metadata and pass it to the downstream worker.
+
+         This function should NOT modify fields in the scheduler_output.
+         Also, calling this function will reset the state of the connector.
+         """
+         meta = TPUConnectorMetadata()
+
+         if self.is_producer:
+             meta.reqs_to_send = self.reqs_to_send
+             self.reqs_to_send = {}
+         else:
+             meta.reqs_to_load = self.reqs_to_load
+             self.reqs_to_load = {}
+
+         return meta
+
+     def get_finished_count(self) -> int:
+         """
+         Return how many workers need to pull the KV cache and report back.
+         """
+         return len(self.kv_ip) if isinstance(self.kv_ip, list) else 1
+
+     def request_finished(
+         self,
+         request: "Request",
+         block_ids: list[int],
+     ) -> tuple[bool, Optional[dict[str, Any]]]:
+         """
+         Called when a request has finished, before its blocks are freed.
+         No-op for D workers.
+
+         Args:
+             request (Request): the request object.
+             block_ids: The block IDs allocated for this request, which need
+                 to be freed.
+         Returns:
+             True if the request is being saved/sent asynchronously and blocks
+             should not be freed until the request_id is returned from
+             get_finished().
+             Optional KVTransferParams to be included in the request outputs
+             returned by the engine.
+         """
+         if not self.is_producer:
+             return False, None
+
+         # Mark the request finished only if the prefill is done and it generated
+         # 1 output token. The request's max_tokens has been reset to 1, so it
+         # must finish as length-capped.
+         if request.status != RequestStatus.FINISHED_LENGTH_CAPPED:
+             return False, None
+
+         # NOTE(xiang): Get the computed blocks rounded by block_size.
+         # This means that for the last, partially filled block we don't bother
+         # transferring KV-cache; we just let D run that prefill locally.
+         all_full = request.num_computed_tokens % self.block_size == 0
+         computed_block_ids = block_ids if all_full else block_ids[:-1]
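+         # Worked example (illustrative): with block_size=128 and
+         # num_computed_tokens=300, 300 % 128 != 0, so the last (partially
+         # filled) block is dropped and only the two full blocks are offered.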
+
+         # If prompt < block_size, there is no transfer, so free blocks immediately.
+         delay_free_blocks = len(computed_block_ids) > 0
+         if delay_free_blocks:
+             uuid = get_uuid()
+             expiration_time = time.perf_counter() + P2P_WAIT_PULL_TIMEOUT
+             self.reqs_to_send[request.request_id] = SendMeta(
+                 uuid=uuid,
+                 local_block_ids=computed_block_ids,
+                 expiration_time=expiration_time)
+             kv_transfer_params = dict(uuid=uuid,
+                                       remote_block_ids=computed_block_ids,
+                                       remote_host=self.kv_ip,
+                                       remote_port=self.kv_port)
+             logger.info(
+                 f"TPUConnector Scheduler ----> generated reqs_to_send={self.reqs_to_send} | "
+                 f"kv_transfer_params={kv_transfer_params}")
+         else:
+             kv_transfer_params = {}
+
+         return delay_free_blocks, kv_transfer_params
+
+
+ class TPUConnectorWorker:
+
+     def __init__(self, vllm_config: VllmConfig):
+         self.vllm_config = vllm_config
+         self.config = vllm_config.kv_transfer_config
+         self.is_producer = self.config.is_kv_producer
+
+         self.runner: Optional[TPUModelRunner] = None
+         self.mesh: Optional[Mesh] = None
+         self.multi_host = envs.TPU_MULTIHOST_BACKEND == "ray"
+         # NOTE(xiang): This cannot be the worker rank set in RayDistributedExecutor.
+         # The worker rank is assigned with vLLM's sorting logic, which does not work
+         # for TPU host topology.
+         self.node_id = get_node_id()
+
+         # req_id -> [kv, expiration_time]
+         self.reqs_wait_pull: dict[ReqId, list[Any]] = {}
+         # req_id -> thread future
+         self.reqs_pulling: dict[ReqId, Future] = {}
+
+         self.host_ip = get_host_ip()
+         self.kv_transfer_port = get_kv_transfer_port()
+         self.side_channel_port = get_side_channel_port()
+
+         self.kv_transfer_server = None
+         self.zmq_cxt = zmq.Context()
+         if self.is_producer:
+             ready_event = threading.Event()
+             self.pull_notify_listener_t = threading.Thread(
+                 target=self._pull_notify_listener,
+                 args=(ready_event, ),
+                 daemon=True,
+             )
+             self.pull_notify_listener_t.start()
+             ready_event.wait()
+         else:
+             self.pull_executor = ThreadPoolExecutor(max_workers=64)
+             self.pull_conns: dict[str, Any] = {}
+             self.notif_sockets: dict[str, zmq.Socket] = {}
+
+         logger.info(f"TPUConnector Worker {self.node_id} --> init | "
+                     f"ip={self.host_ip} | "
+                     f"kv_transfer_port={self.kv_transfer_port} | "
+                     f"side_channel_port={self.side_channel_port}")
+
+     def __del__(self):
+         if self.is_producer:
+             if hasattr(self, "pull_notify_listener_t"):
+                 self.pull_notify_listener_t.join(timeout=0)
+         else:
+             if hasattr(self, "pull_executor"):
+                 self.pull_executor.shutdown(wait=False)
+         if hasattr(self, "zmq_cxt"):
+             self.zmq_cxt.destroy(linger=0)
+
+     def register_runner(self, runner: TPUModelRunner):
+         self.runner = runner
+         self.mesh = runner.mesh
+
+         # Get the spec of the kv_caches
+         kv_caches = runner.kv_caches
+         kv_layer = kv_caches[0]
+         self.num_layers = len(kv_caches)
+         self.shape = list(kv_layer.shape)
+         self.dtype = kv_layer.dtype
+         self.sharding = kv_layer.sharding
+         self._maybe_start_p2p_server()
+
+     def _maybe_start_p2p_server(self):
+         if self.kv_transfer_server is not None:
+             return
+         server_addr = f"{self.host_ip}:{self.kv_transfer_port}"
+         transport_addr = f'{self.host_ip}:0'
+         self.kv_transfer_server = start_transfer_server(
+             jax.local_devices()[0].client,
+             server_addr,
+             [transport_addr],
+             max_num_parallel_copies=8,
+             transfer_size=256 * 1024 * 1024,
+             use_raw_buffers=False,
+         )
+         logger.info(
+             f"TPUConnector Worker {self.node_id} --> kv transfer | addr={self.kv_transfer_server.address()}"
+         )
+
+     def _pull_notify_listener(self, ready_event: threading.Event):
+         sock_path = make_zmq_path("tcp", "*", self.side_channel_port)
+         sock = make_zmq_socket(ctx=self.zmq_cxt,
+                                path=sock_path,
+                                socket_type=zmq.ROUTER,
+                                bind=True)
+         ready_event.set()
+         logger.info(
+             f"TPUConnector Worker {self.node_id} --> zmq listener | sock_path={sock_path}"
+         )
+
+         while True:
+             client_id, req_id_bytes = sock.recv_multipart()
+             req_id = req_id_bytes.decode('utf-8')
+             logger.info(
+                 f"TPUConnector Worker {self.node_id} --> zmq receive | req_id={req_id}"
+             )
+             if req_id in self.reqs_wait_pull:
+                 # Set the expiration time of this request to -1 to mark it done.
+                 self.reqs_wait_pull[req_id][1] = -1
+             else:
+                 raise ValueError(
+                     f"Disagg producer received a pull-finished notification for a non-existent request {req_id}"
+                 )
+             time.sleep(0)
+             # The response is not really needed.
+             # sock.send_multipart([client_id, b"", b"ACK"])
+
+     def process_send_load(self, metadata: TPUConnectorMetadata):
+         """
+         This is called in the runner before model forward, regardless of
+         whether scheduler_output.total_num_scheduled_tokens is zero.
+         """
+         reqs = metadata.reqs_to_send
+         if reqs:
+             assert self.is_producer
+             logger.info(
+                 f"TPUConnector Worker {self.node_id} --> reqs_to_send={reqs}")
+             for req_id, req_meta in reqs.items():
+                 self._prepare_kv_and_wait(req_id, req_meta)
+
+         reqs = metadata.reqs_to_load
+         if reqs:
+             assert not self.is_producer
+             logger.info(
+                 f"TPUConnector Worker {self.node_id} --> reqs_to_load={reqs}")
+             for req_id, req_meta in reqs.items():
+                 if req_meta.remote_block_ids is not None:
+                     # The request needs to pull KV from P; build the connection
+                     # and pull the data asynchronously.
+                     conn = self._maybe_build_kv_connection(req_meta)
+                     self.reqs_pulling[req_id] = self.pull_executor.submit(
+                         self._pull_kv, conn, req_meta)
+                 else:
+                     # The request has finished pulling the KV from remote, or it
+                     # has a full local prefix cache; notify P so it can free
+                     # the blocks.
+                     socket = self._maybe_build_notif_socket(req_meta)
+                     self._notify_pull_done(socket, req_id)
+
+     def _prepare_kv_and_wait(self, req_id: str, req_meta: SendMeta):
+         local_block_ids = req_meta.local_block_ids
+         # TODO(xiang): pad block_ids to avoid recompilation
+         indices = device_array(self.mesh, np.array(local_block_ids))
+         kv = select_from_kv_caches(self.runner.kv_caches, indices)
+         # NOTE(xiang): We need to manually store the kv because:
+         # although we can set use_raw_buffers=True to let kv be safely destroyed
+         # after calling await_pull, it could become a stranded buffer if D never
+         # pulls it. So we have to set use_raw_buffers=False and store the kv;
+         # the kv buffer will then be safely destroyed by either D's notification
+         # or expiration.
+         self.reqs_wait_pull[req_id] = [kv, req_meta.expiration_time]
+         self.kv_transfer_server.await_pull(req_meta.uuid, kv)
+
+     def _maybe_build_kv_connection(self, req_meta: LoadMeta) -> Any:
+         if isinstance(req_meta.remote_host, list):
+             assert len(req_meta.remote_host) == len(req_meta.remote_port)
+             remote_addr = f"{req_meta.remote_host[self.node_id]}:{req_meta.remote_port[self.node_id]}"
+         else:
+             remote_addr = f"{req_meta.remote_host}:{req_meta.remote_port}"
+
+         if remote_addr in self.pull_conns:
+             conn = self.pull_conns[remote_addr]
+         else:
+             conn = self.kv_transfer_server.connect(remote_addr)
+             self.pull_conns[remote_addr] = conn
+             logger.info(
+                 f"Worker {self.node_id} --> kv transfer | connect={remote_addr}"
+             )
+         return conn
+
+     def _pull_kv(self, conn: Any, req_meta: LoadMeta):
+         # The locally allocated blocks that didn't hit the prefix cache.
+         local_block_ids = req_meta.local_block_ids
+         # The remote computed blocks that need to be pulled from P.
+         remote_block_ids = req_meta.remote_block_ids
+         # Make sure they have the same number of blocks, because we don't
+         # handle partial prefix-cache hits for now.
+         assert len(local_block_ids) == len(remote_block_ids)
+
+         kv_spec = self._get_kv_spec(len(remote_block_ids))
+         # TODO(xiang): pad block_ids to avoid recompilation
+         indices = device_array(self.mesh, np.array(local_block_ids))
+         kv = conn.pull(req_meta.uuid, kv_spec)
+         logger.info(
+             f"Worker {self.node_id} --> kv transfer | pull uuid={req_meta.uuid}"
+         )
+         return kv, indices
+
+     def _get_kv_spec(self, num_blocks: int) -> list[jax.ShapeDtypeStruct]:
+         assert num_blocks <= self.shape[0]
+         shape = copy.copy(self.shape)
+         shape[0] = num_blocks
+         return [
+             jax.ShapeDtypeStruct(shape, self.dtype, sharding=self.sharding)
+         ] * self.num_layers
+
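+     # Illustrative: if each layer's cache has shape
+     # [total_blocks, block_size, num_heads, head_dim], then _get_kv_spec(k)
+     # yields num_layers ShapeDtypeStructs of shape
+     # [k, block_size, num_heads, head_dim] with the cache's dtype and sharding.
+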
+     def _maybe_build_notif_socket(self, req_meta: LoadMeta) -> zmq.Socket:
+         remote_host = req_meta.remote_host
+         if isinstance(req_meta.remote_host, list):
+             remote_host = req_meta.remote_host[self.node_id]
+
+         sock_path = make_zmq_path("tcp", remote_host, self.side_channel_port)
+         if sock_path in self.notif_sockets:
+             sock = self.notif_sockets[sock_path]
+         else:
+             sock = make_zmq_socket(ctx=self.zmq_cxt,
+                                    path=sock_path,
+                                    socket_type=zmq.DEALER,
+                                    bind=False)
+             # Cache the socket so it is reused on the next lookup.
+             self.notif_sockets[sock_path] = sock
+             logger.info(
+                 f"Worker {self.node_id} --> zmq notify | sock_path={sock_path}"
+             )
+         return sock
+
+     def _notify_pull_done(self, sock: zmq.Socket, req_id: str):
+         logger.info(f"Worker {self.node_id} --> zmq notify | req_id={req_id}")
+         sock.send_string(req_id)
+         # The response is not really needed.
+         # ack = sock.recv_string()
+
+     def get_finished(self) -> tuple[set[str], set[str]]:
+         done_sending: set[str] = set()
+         done_recving: set[str] = set()
+         if not self.reqs_wait_pull and not self.reqs_pulling:
+             return done_sending, done_recving
+
+         # Mark a req as done receiving after its pulling thread returns.
+         # This req can then be scheduled for decoding in the next scheduler step.
+         for req_id in list(self.reqs_pulling.keys()):
+             future = self.reqs_pulling[req_id]
+             if future.done():
+                 # NOTE(xiang): we do the scatter in the main thread to avoid a
+                 # data race. The race is not on the kv_caches buffer; it's on
+                 # the runner.kv_caches ref.
+                 kv, indices = future.result()
+                 self.runner.kv_caches = scatter_kv_slices(
+                     self.runner.kv_caches, kv, indices)
+                 del self.reqs_pulling[req_id]
+                 done_recving.add(req_id)
+
+         # Mark a req as done sending when it has expired (or when D's
+         # notification set its expiration to -1).
+         # Its blocks can then be released in the current scheduler step.
+         now = time.perf_counter()
+         for req_id in list(self.reqs_wait_pull):
+             _, expires = self.reqs_wait_pull[req_id]
+             if now > expires:
+                 del self.reqs_wait_pull[req_id]
+                 done_sending.add(req_id)
+         if done_sending:
+             logger.info(
+                 f"Worker {self.node_id} --> done_sending={done_sending}")
+         if done_recving:
+             logger.info(
+                 f"Worker {self.node_id} --> done_recving={done_recving}")
+         return done_sending, done_recving
+
+
+ def get_uuid() -> int:
+     int128 = uuid4().int
+     # Must be a 64-bit int, otherwise the vLLM output encoder would raise an error.
+     int64 = int128 >> 64
+     return int64
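+
+ # Example: uuid4().int is 128 bits; `>> 64` keeps the top 64 bits, e.g.
+ # 0x0123456789abcdeffedcba9876543210 >> 64 == 0x0123456789abcdef.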
+
+
+ @jax.jit
+ def select_from_kv_caches(kv_caches: list[jax.Array],
+                           indices: jax.Array) -> list[jax.Array]:
+     # Gather the rows at `indices` (block dim 0) from every layer's cache.
+     selected = [cache.at[indices].get() for cache in kv_caches]
+     return selected
+
+
+ # Donating kv_caches lets XLA reuse the old cache buffers for the updated
+ # caches instead of allocating new ones.
+ @functools.partial(
+     jax.jit,
+     donate_argnames=("kv_caches", ),
+ )
+ def scatter_kv_slices(kv_caches: list[jax.Array], kv_slices: list[jax.Array],
+                       indices: jax.Array) -> list[jax.Array]:
+     num_indices = indices.shape[0]
+     num_slices = kv_slices[0].shape[0]
+     # indices might be padded
+     assert num_slices <= num_indices
+
+     new_kv_caches = []
+     for cache, kv_slice in zip(kv_caches, kv_slices):
+         if num_slices < num_indices:
+             kv_slice = jnp.pad(kv_slice, ((0, num_indices - num_slices), (0, 0),
+                                           (0, 0), (0, 0)))
+         new_cache = cache.at[indices].set(kv_slice)
+         new_kv_caches.append(new_cache)
+     return new_kv_caches
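+
+
+ # Minimal round-trip sketch (illustrative; `mesh`, `kv_caches` and the block
+ # ids are assumptions, not part of this module):
+ #   indices = device_array(mesh, np.array([3, 7]))
+ #   kv = select_from_kv_caches(kv_caches, indices)         # gather blocks 3, 7
+ #   kv_caches = scatter_kv_slices(kv_caches, kv, indices)  # write them back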