tpu-inference 0.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_adapters.py +83 -0
- tests/core/test_core_tpu.py +523 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/test_lora.py +123 -0
- tests/test_base.py +201 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +218 -0
- tests/tpu_backend_test.py +59 -0
- tpu_inference/__init__.py +30 -0
- tpu_inference/adapters/__init__.py +0 -0
- tpu_inference/adapters/vllm_adapters.py +42 -0
- tpu_inference/adapters/vllm_config_adapters.py +134 -0
- tpu_inference/backend.py +69 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/adapters.py +153 -0
- tpu_inference/core/core_tpu.py +776 -0
- tpu_inference/core/disagg_executor.py +117 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/di/__init__.py +0 -0
- tpu_inference/di/abstracts.py +28 -0
- tpu_inference/di/host.py +76 -0
- tpu_inference/di/interfaces.py +51 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/tpu_connector.py +699 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +346 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/interfaces/__init__.py +0 -0
- tpu_inference/interfaces/cache.py +31 -0
- tpu_inference/interfaces/config.py +47 -0
- tpu_inference/interfaces/config_parts.py +117 -0
- tpu_inference/interfaces/engine.py +51 -0
- tpu_inference/interfaces/outputs.py +22 -0
- tpu_inference/interfaces/params.py +21 -0
- tpu_inference/interfaces/platform.py +74 -0
- tpu_inference/interfaces/request.py +39 -0
- tpu_inference/interfaces/scheduler.py +31 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +254 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/attention_interface.py +356 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/binary_search.py +295 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +172 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +95 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
- tpu_inference/layers/jax/sharding.py +406 -0
- tpu_inference/layers/jax/transformer_block.py +76 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +184 -0
- tpu_inference/layers/vllm/fused_moe.py +399 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +34 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
- tpu_inference/layers/vllm/sharding.py +151 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +308 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1233 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +433 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/llama3.py +366 -0
- tpu_inference/models/jax/llama4.py +473 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +976 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
- tpu_inference/models/jax/utils/weight_utils.py +510 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_jax.py +257 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table_jax.py +122 -0
- tpu_inference/runner/compilation_manager.py +672 -0
- tpu_inference/runner/input_batch_jax.py +435 -0
- tpu_inference/runner/kv_cache.py +119 -0
- tpu_inference/runner/kv_cache_manager.py +460 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +208 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +250 -0
- tpu_inference/runner/structured_decoding_manager.py +89 -0
- tpu_inference/runner/tpu_jax_runner.py +771 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +334 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +294 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/_temporary_vllm_compat.py +129 -0
- tpu_inference/worker/base.py +100 -0
- tpu_inference/worker/tpu_worker_jax.py +321 -0
- tpu_inference-0.11.1.dist-info/METADATA +101 -0
- tpu_inference-0.11.1.dist-info/RECORD +168 -0
- tpu_inference-0.11.1.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
@@ -0,0 +1,699 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Proxy server routes the request to P with max_output_tokens=1

P workflow:
P receives the request

P scheduler checks if the prefill is fully done in `request_finished()`
If done:
    P puts the request-id in `scheduler_output.finished_req_ids`
    and puts the request in `scheduler_output.kv_connector_metadata.reqs_to_send`
    P responds to the proxy server with `finished_req_ids` and the `kv_transfer_params`
    P worker gets `reqs_to_send` and runs async `_prepare_kv_and_wait()`
Else:
    P schedules the prefill with multiple turns due to chunked-prefill.

P worker checks if the request has been pulled by D
If done:
    P worker puts the request-id in `done_sending`
    P scheduler frees blocks for the request in done sending.
Else:
    P holds the blocks for the request until it's pulled by D

(
One scheduler step can finish:
    scheduler RUNNING -> connector reqs_to_send -> worker prefill -> output
The waiting buffer gets freed after D's notification, or on expiration.
)

Proxy server receives the response from P and forwards it to D

D workflow:
D receives the request

D scheduler calculates the number of tokens that need to be pulled from P in `get_num_new_matched_tokens()`
D checks whether it needs to pull from P
If true:
    D puts the request in `scheduler_output.kv_connector_metadata.reqs_to_load`
    D worker gets `reqs_to_load` and runs `_pull_and_write_kv()` in separate threads (to be async)
    D worker checks if the async loading is done:
    If done:
        D worker puts the request-id in `done_recving`.
        D scheduler then knows the request can be scheduled for decoding now. The model decode
        will happen in the next scheduler step.
    Else:
        D worker handles other requests first.
Else (too short prompt, full local prefix-cache):
    D still needs to put the request in `reqs_to_load` but with None metadata, because D needs to
    notify P that the prefilled KV cache is no longer needed and can be freed in P.

(
Two scheduler steps can finish:
    scheduler WAITING_FOR_REMOTE_KVS -> connector reqs_to_load -> worker wait for pulling
    worker pulling done, notify P to free blocks
    scheduler RUNNING -> connector reqs_to_load=None -> worker decode -> output
The waiting buffer gets freed after D's notification, or on expiration.
)
"""

import copy
import functools
import os
import threading
import time
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional
from uuid import uuid4

import jax
import jax.numpy as jnp
import numpy as np
import zmq
from jax.experimental.transfer import start_transfer_server
from jax.sharding import Mesh
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus

if TYPE_CHECKING:
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.request import Request

from tpu_inference.distributed.utils import (get_host_ip, get_kv_ips,
                                             get_kv_ports,
                                             get_kv_transfer_port, get_node_id,
                                             get_side_channel_port)
from tpu_inference.logger import init_logger
from tpu_inference.runner.tpu_jax_runner import TPUModelRunner
from tpu_inference.utils import device_array

ReqId = str

# Feature requests:
# 1. support async pulling natively
# 2. partial pulling (like RDMA)
# 3. non-blocking jax array read/write

# A KV cache awaiting pull will be cleared after this
# time (in seconds) if no pull has occurred on it.
P2P_WAIT_PULL_TIMEOUT = 120

logger = init_logger(__name__)


@dataclass
class SendMeta:
    uuid: int
    local_block_ids: list[int]
    expiration_time: float


@dataclass
class LoadMeta:
    uuid: int
    local_block_ids: list[int]
    remote_block_ids: list[int]
    remote_host: str | list[str]
    remote_port: int | list[int]


@dataclass
class _kv_transfer_params:
    """
    P prepares this in request_finished() and responds to the proxy server.
    D receives this from the proxy server and uses it to create LoadMeta.
    """
    uuid: int
    remote_block_ids: list[int]
    # A single IP for single-host, or a list of IPs for multi-host.
    remote_host: str | list[str]
    # A single port for single-host, or a list of ports for multi-host.
    remote_port: int | list[int]
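    # Example payload as relayed through the proxy (illustrative values only):
    #   {"uuid": 1234567890123456789,
    #    "remote_block_ids": [0, 1, 2],
    #    "remote_host": "10.0.0.1",   # or ["10.0.0.1", "10.0.0.2"] for multi-host
    #    "remote_port": 7000}         # or [7000, 7001]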


# The metadata used for communication between the scheduler and worker connectors.
@dataclass
class TPUConnectorMetadata(KVConnectorMetadata):
    reqs_to_send: dict[ReqId, SendMeta] = field(default_factory=dict)
    reqs_to_load: dict[ReqId, LoadMeta] = field(default_factory=dict)


class TPUConnector(KVConnectorBase_V1):

    def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
        assert vllm_config.kv_transfer_config is not None

        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler = \
                TPUConnectorScheduler(vllm_config)
            self.connector_worker = None
        elif role == KVConnectorRole.WORKER:
            self.connector_scheduler = None
            self.connector_worker = TPUConnectorWorker(vllm_config)

    ############################################################
    # Scheduler Side Methods
    ############################################################
    def get_num_new_matched_tokens(
            self, request: "Request",
            num_computed_tokens: int) -> tuple[int, bool]:
        assert self.connector_scheduler is not None
        return self.connector_scheduler.get_num_new_matched_tokens(
            request, num_computed_tokens)

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        assert self.connector_scheduler is not None
        return self.connector_scheduler.update_state_after_alloc(
            request, blocks, num_external_tokens)

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> TPUConnectorMetadata:
        assert self.connector_scheduler is not None
        return self.connector_scheduler.build_connector_meta()

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        assert self.connector_scheduler is not None
        return self.connector_scheduler.request_finished(request, block_ids)

    ############################################################
    # Worker Side Methods
    ############################################################
    def register_kv_caches(self, kv_caches: list[jax.Array]):
        """
        We don't register kv_caches in the connector; we call `register_runner`
        and use runner.kv_caches directly instead, because the reference to
        runner.kv_caches is reassigned during the model forward pass.
        """
        pass

    def register_runner(self, runner: TPUModelRunner) -> None:
        assert self.connector_worker is not None
        self.connector_worker.register_runner(runner)

    def start_load_kv(self, _, **kwargs) -> None:
        assert self.connector_worker is not None
        assert isinstance(self._connector_metadata, TPUConnectorMetadata)
        self.connector_worker.process_send_load(self._connector_metadata)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """TPU connector doesn't support layer-wise load."""
        pass

    def save_kv_layer(self, **kwargs) -> None:
        """TPU connector doesn't support layer-wise save."""
        pass

    def wait_for_save(self):
        """
        Not useful for TPU: by the design of vLLM's KVConnectorModelRunnerMixin,
        this function is only called when scheduler_output.total_num_scheduled_tokens
        is not 0. But reqs_to_send only becomes available after the request has
        finished prefilling, at which point total_num_scheduled_tokens can be 0
        if there are no other running requests.
        So we run the saving logic in `start_load_kv -> process_send_load` instead.
        """
        pass

    def get_finished(self,
                     finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        assert self.connector_worker is not None
        return self.connector_worker.get_finished()
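
# Usage sketch (illustrative; assumes vLLM's standard per-role connector wiring):
#   scheduler process: TPUConnector(vllm_config, KVConnectorRole.SCHEDULER)
#   worker process:    TPUConnector(vllm_config, KVConnectorRole.WORKER)
# Each role only exercises its own half of the methods above; calling into the
# other half trips the assert on the None delegate.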


class TPUConnectorScheduler():

    def __init__(self, vllm_config: "VllmConfig"):
        self.vllm_config = vllm_config
        self.config = vllm_config.kv_transfer_config
        self.is_producer = self.config.is_kv_producer

        self.block_size = vllm_config.cache_config.block_size

        # This is updated in self.request_finished() for P;
        # each request that has finished prefilling is added to it.
        self.reqs_to_send: dict[ReqId, SendMeta] = {}

        # This is updated in self.update_state_after_alloc() for D;
        # each request that needs to pull KV cache from remote is added to it.
        self.reqs_to_load: dict[ReqId, LoadMeta] = {}

        self.kv_ip = get_kv_ips()
        self.kv_port = get_kv_ports()
        logger.info(
            f"Scheduler --> kv_ip={self.kv_ip} | kv_port={self.kv_port}")

    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int, bool]:
        """
        D workers use this to get the number of new tokens
        that can be loaded from remote P workers.
        No-op for P workers.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            A tuple with the following elements:
                - The number of tokens that will be loaded from the
                  external KV cache.
                - Whether the loading is asynchronous. True whenever tokens
                  will be pulled: the JAX P2P pull itself blocks, but the
                  worker runs it in a separate thread (see the NOTE below).
        """
        if self.is_producer:
            return 0, False

        assert num_computed_tokens % self.block_size == 0
        # This rounding logic must be consistent with calculating
        # remote_block_ids in P's request_finished()
        rounded_num_prompt_tokens = round_down(len(request.prompt_token_ids),
                                               self.block_size)
        count = max(rounded_num_prompt_tokens - num_computed_tokens, 0)
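        # Worked example (illustrative): with block_size=16 and a 50-token
        # prompt, rounded_num_prompt_tokens = 48. If 16 tokens already hit the
        # local prefix cache, count = 48 - 16 = 32 tokens are pulled from P;
        # the trailing partial block (tokens 48-49) is prefilled locally on D.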
        # NOTE(xiang): Although the JAX P2P pulling is a blocking op, we run it in a
        # separate thread to make it async, so we are safe to return True here.
        if count > 0:
            return count, True
        return 0, False

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """
        Update states after block allocation.
        No-op for P workers.

        Args:
            request (Request): the request object.
            blocks (KVCacheBlocks): the blocks allocated for the request.
            num_external_tokens (int): the number of tokens that will be
                loaded from the external KV cache.
        """
        if self.is_producer:
            return

        params = request.kv_transfer_params
        if num_external_tokens > 0:
            # We need to load KV-cache from remote (partial prefix cache hit).
            local_block_ids = blocks.get_block_ids()[0]

            # NOTE(xiang): D needs to pull the whole set of prefill blocks from the
            # remote, regardless of how much of the prefix cache hits.
            # The reason is that JAX P2P doesn't work like RDMA; instead, P prepares
            # the whole prefilled data and waits for it to be pulled, then D pulls the
            # whole data. This means that even with a partial prefix cache hit on D,
            # D cannot pull only the remaining partial data from P.
            # Unless we implement a side channel to let P know the prefix cache hit
            # info on D (so P can prepare only the non-hit KV), in which case we
            # would need to change this to:
            # local_block_ids = blocks.get_unhashed_block_ids()

            self.reqs_to_load[request.request_id] = LoadMeta(
                uuid=params["uuid"],
                local_block_ids=local_block_ids,
                remote_block_ids=params["remote_block_ids"],
                remote_host=params["remote_host"],
                remote_port=params["remote_port"],
            )
        else:
            # This branch covers two cases:
            # 1. We don't need to load KV-cache from remote because of a full local cache.
            # 2. The async pulling is done.
            # In both cases we need to send a notification to let P free memory.
            self.reqs_to_load[request.request_id] = LoadMeta(
                uuid=params["uuid"],
                local_block_ids=None,
                remote_block_ids=None,
                remote_host=params["remote_host"],
                remote_port=params["remote_port"],
            )
        logger.info(f"Scheduler --> reqs_to_load={self.reqs_to_load}")

    def build_connector_meta(self) -> TPUConnectorMetadata:
        """
        Build the scheduler metadata and pass it to the downstream worker.

        This function should NOT modify fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.
        """
        meta = TPUConnectorMetadata()

        if self.is_producer:
            meta.reqs_to_send = self.reqs_to_send
            self.reqs_to_send = {}
        else:
            meta.reqs_to_load = self.reqs_to_load
            self.reqs_to_load = {}

        return meta

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """
        Called when a request has finished, before its blocks are freed.
        No-op for D workers.

        Args:
            request (Request): the request object.
            block_ids: The block IDs allocated for this request, which need to be freed.
        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """
        if not self.is_producer:
            return False, None

        # Mark the request finished only if the prefill is done and it has generated
        # 1 output token. The request's max_tokens has been reset to 1, so it must
        # finish as length-capped.
        if request.status != RequestStatus.FINISHED_LENGTH_CAPPED:
            return False, None

        # NOTE(xiang): Get the computed blocks, rounded down by block_size.
        # This means that for the last, partially filled block we won't bother
        # transferring KV-cache; we just let D run the prefill locally.
        all_full = request.num_computed_tokens % self.block_size == 0
        computed_block_ids = block_ids if all_full else block_ids[:-1]
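        # Worked example (illustrative): with block_size=16,
        # num_computed_tokens=50 and block_ids=[3, 7, 9, 12], 50 % 16 != 0, so
        # the last, partially filled block (12) is dropped and only blocks
        # [3, 7, 9] are offered to D; D re-prefills the tail tokens locally.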

        # If prompt < block_size, there is no transfer, so free blocks immediately.
        delay_free_blocks = len(computed_block_ids) > 0

        if delay_free_blocks:
            uuid = get_uuid()
            expiration_time = time.perf_counter() + P2P_WAIT_PULL_TIMEOUT
            self.reqs_to_send[request.request_id] = SendMeta(
                uuid=uuid,
                local_block_ids=computed_block_ids,
                expiration_time=expiration_time)
            kv_transfer_params = dict(uuid=uuid,
                                      remote_block_ids=computed_block_ids,
                                      remote_host=self.kv_ip,
                                      remote_port=self.kv_port)
            logger.info(f"Scheduler ----> reqs_to_send={self.reqs_to_send} | "
                        f"kv_transfer_params={kv_transfer_params}")
        else:
            kv_transfer_params = {}

        return delay_free_blocks, kv_transfer_params


class TPUConnectorWorker:

    def __init__(self, vllm_config: VllmConfig):
        self.vllm_config = vllm_config
        self.config = vllm_config.kv_transfer_config
        self.is_producer = self.config.is_kv_producer

        self.runner: TPUModelRunner = None
        self.mesh: Mesh = None
        self.multi_host = os.getenv("TPU_MULTIHOST_BACKEND",
                                    "").lower() == "ray"
        # NOTE(xiang): This cannot be the worker rank set in RayDistributedExecutor.
        # The worker rank is assigned with vLLM's sorting logic, which does not work
        # for TPU host topology.
        self.node_id = get_node_id()

        # req_id: (kv, expiration_time)
        self.reqs_wait_pull: dict[ReqId, list[list[jax.Array], float]] = {}
        # req_id: thread_future
        self.reqs_pulling: dict[ReqId, Future] = {}

        self.host_ip = get_host_ip()
        self.kv_transfer_port = get_kv_transfer_port()
        self.side_channel_port = get_side_channel_port()

        self.kv_transfer_server = None
        self._maybe_start_p2p_server()
        self.zmq_cxt = zmq.Context()
        if self.is_producer:
            ready_event = threading.Event()
            self.pull_notify_listener_t = threading.Thread(
                target=self._pull_notify_listener,
                args=(ready_event, ),
                daemon=True,
            )
            self.pull_notify_listener_t.start()
            ready_event.wait()
        else:
            self.pull_executor = ThreadPoolExecutor(max_workers=64)
            self.pull_conns: dict[str, Any] = {}
            self.notif_sockets: dict[str, zmq.Socket] = {}

        logger.info(f"Worker {self.node_id} --> init | "
                    f"ip={self.host_ip} | "
                    f"kv_transfer_port={self.kv_transfer_port} | "
                    f"side_channel_port={self.side_channel_port}")

    def __del__(self):
        if self.is_producer:
            self.pull_notify_listener_t.join(timeout=0)
        else:
            self.pull_executor.shutdown(wait=False)
        self.zmq_cxt.destroy(linger=0)

    def register_runner(self, runner: TPUModelRunner):
        self.runner = runner
        self.mesh = runner.mesh

        # Get the spec of the kv_caches
        kv_caches = runner.kv_caches
        kv_layer = kv_caches[0]
        self.num_layers = len(kv_caches)
        self.shape = list(kv_layer.shape)
        self.dtype = kv_layer.dtype
        self.sharding = kv_layer.sharding

    def _maybe_start_p2p_server(self):
        if self.kv_transfer_server is not None:
            return
        server_addr = f"{self.host_ip}:{self.kv_transfer_port}"
        transport_addr = f'{self.host_ip}:0'
        self.kv_transfer_server = start_transfer_server(
            jax.local_devices()[0].client,
            server_addr,
            [transport_addr],
            max_num_parallel_copies=8,
            transfer_size=256 * 1024 * 1024,
            use_raw_buffers=False,
        )
        logger.info(
            f"Worker {self.node_id} --> kv transfer | addr={self.kv_transfer_server.address()}"
        )

    def _pull_notify_listener(self, ready_event: threading.Event):
        sock_path = make_zmq_path("tcp", "*", self.side_channel_port)
        sock = make_zmq_socket(ctx=self.zmq_cxt,
                               path=sock_path,
                               socket_type=zmq.ROUTER,
                               bind=True)
        ready_event.set()
        logger.info(
            f"Worker {self.node_id} --> zmq listener | sock_path={sock_path}")

        while True:
            client_id, req_id_bytes = sock.recv_multipart()
            req_id = req_id_bytes.decode('utf-8')
            logger.info(
                f"Worker {self.node_id} --> zmq receive | req_id={req_id}")
            if req_id in self.reqs_wait_pull:
                # Set the expiration time of this request to -1 to mark it as done
                self.reqs_wait_pull[req_id][1] = -1
            else:
                raise ValueError(
                    f"Disagg producer received a pull-finished notification for a nonexistent request {req_id}"
                )
            time.sleep(0)
            # The response is not really needed.
            # sock.send_multipart([client_id, b"", b"ACK"])

    def process_send_load(self, metadata: TPUConnectorMetadata):
        """
        This is called in the runner before the model forward pass, whether or not
        scheduler_output.total_num_scheduled_tokens is 0.
        """
        reqs = metadata.reqs_to_send
        if reqs:
            assert self.is_producer
            logger.info(f"Worker {self.node_id} --> reqs_to_send={reqs}")
            for req_id, req_meta in reqs.items():
                self._prepare_kv_and_wait(req_id, req_meta)

        reqs = metadata.reqs_to_load
        if reqs:
            assert not self.is_producer
            logger.info(f"Worker {self.node_id} --> reqs_to_load={reqs}")
            for req_id, req_meta in reqs.items():
                if req_meta.remote_block_ids is not None:
                    # The request needs to pull KV from P; build the connection
                    # and pull the data asynchronously.
                    conn = self._maybe_build_kv_connection(req_meta)
                    self.reqs_pulling[req_id] = self.pull_executor.submit(
                        self._pull_kv, conn, req_meta)
                else:
                    # The request has finished pulling the KV from remote, or it has
                    # a full local prefix cache; notify P so it can free the blocks.
                    socket = self._maybe_build_notif_socket(req_meta)
                    self._notify_pull_done(socket, req_id)

    def _prepare_kv_and_wait(self, req_id: str, req_meta: SendMeta):
        local_block_ids = req_meta.local_block_ids
        # TODO(xiang): pad block_ids to avoid recompilation
        indices = device_array(self.mesh, np.array(local_block_ids))
        kv = select_from_kv_caches(self.runner.kv_caches, indices)
        # NOTE(xiang): We need to manually store the kv because, although we could
        # set use_raw_buffers=True to let kv be safely destroyed right after calling
        # await_pull, it would become a stranded buffer if D never pulls it.
        # So we set use_raw_buffers=False and store the kv ourselves; the kv buffer
        # is then safely destroyed on either D's notification or expiration.
        self.reqs_wait_pull[req_id] = [kv, req_meta.expiration_time]
        self.kv_transfer_server.await_pull(req_meta.uuid, kv)
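        # Lifecycle (illustrative summary): the stored kv reference keeps the
        # selected buffers alive; get_finished() drops them either after D's
        # zmq notification sets expiration_time to -1 or once the timeout
        # passes, and only then are the underlying device buffers released.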

    def _maybe_build_kv_connection(self, req_meta: LoadMeta) -> Any:
        remote_addr = f"{req_meta.remote_host}:{req_meta.remote_port}"
        if remote_addr in self.pull_conns:
            conn = self.pull_conns[remote_addr]
        else:
            conn = self.kv_transfer_server.connect(remote_addr)
            self.pull_conns[remote_addr] = conn
            logger.info(
                f"Worker {self.node_id} --> kv transfer | connect={remote_addr}"
            )
        return conn

    def _pull_kv(self, conn: Any, req_meta: LoadMeta):
        # The locally allocated blocks which don't hit the prefix cache.
        local_block_ids = req_meta.local_block_ids
        # The remote computed blocks which need to be pulled from P.
        remote_block_ids = req_meta.remote_block_ids
        # Make sure they have the same number of blocks, because we don't handle
        # partial prefix cache hits for now.
        assert len(local_block_ids) == len(remote_block_ids)

        kv_spec = self._get_kv_spec(len(remote_block_ids))
        # TODO(xiang): pad block_ids to avoid recompilation
        indices = device_array(self.mesh, np.array(local_block_ids))
        kv = conn.pull(req_meta.uuid, kv_spec)
        logger.info(
            f"Worker {self.node_id} --> kv transfer | pull uuid={req_meta.uuid}"
        )
        return kv, indices

    def _get_kv_spec(self, num_blocks: int) -> list[jax.ShapeDtypeStruct]:
        assert num_blocks <= self.shape[0]
        shape = copy.copy(self.shape)
        shape[0] = num_blocks
        return [
            jax.ShapeDtypeStruct(shape, self.dtype, sharding=self.sharding)
        ] * self.num_layers
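    # Worked example (illustrative; the actual cache layout is defined by the
    # runner): if each KV layer were shaped (num_total_blocks, block_size,
    # num_kv_heads, head_dim) = (1024, 16, 16, 128), then _get_kv_spec(3)
    # describes num_layers per-layer arrays of shape (3, 16, 16, 128) with the
    # same dtype and sharding, which is what conn.pull() materializes locally.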

    def _maybe_build_notif_socket(self, req_meta: LoadMeta) -> zmq.Socket:
        sock_path = make_zmq_path("tcp", req_meta.remote_host,
                                  self.side_channel_port)
        if sock_path in self.notif_sockets:
            sock = self.notif_sockets[sock_path]
        else:
            sock = make_zmq_socket(ctx=self.zmq_cxt,
                                   path=sock_path,
                                   socket_type=zmq.DEALER,
                                   bind=False)
            # Cache the socket so later notifications to the same P reuse it.
            self.notif_sockets[sock_path] = sock
            logger.info(
                f"Worker {self.node_id} --> zmq notify | sock_path={sock_path}"
            )
        return sock

    def _notify_pull_done(self, sock: zmq.Socket, req_id: str):
        logger.info(f"Worker {self.node_id} --> zmq notify | req_id={req_id}")
        sock.send_string(req_id)
        # The response is not really needed.
        # ack = sock.recv_string()

    def get_finished(self) -> tuple[set[str], set[str]]:
        done_sending: set[str] = set()
        done_recving: set[str] = set()
        if not self.reqs_wait_pull and not self.reqs_pulling:
            return done_sending, done_recving

        # Mark a req as done receiving after its pulling thread returns.
        # This req can then be scheduled for decoding in the next scheduler step.
        for req_id in list(self.reqs_pulling.keys()):
            future = self.reqs_pulling[req_id]
            if future.done():
                # NOTE(xiang): we do the scatter in the main thread to avoid a data
                # race. The race is not on the kv_caches buffers; it's on the
                # runner.kv_caches reference.
                kv, indices = future.result()
                self.runner.kv_caches = scatter_kv_slices(
                    self.runner.kv_caches, kv, indices)
                del self.reqs_pulling[req_id]
                done_recving.add(req_id)

        # Mark a req as done sending when it has expired.
        # The blocks for this req can then be released in the current scheduler step.
        now = time.perf_counter()
        for req_id in list(self.reqs_wait_pull):
            _, expires = self.reqs_wait_pull[req_id]
            if now > expires:
                del self.reqs_wait_pull[req_id]
                done_sending.add(req_id)
        if done_sending:
            logger.info(
                f"Worker {self.node_id} --> done_sending={done_sending}")
        if done_recving:
            logger.info(
                f"Worker {self.node_id} --> done_recving={done_recving}")
        return done_sending, done_recving


def get_uuid() -> int:
    int128 = uuid4().int
    # Must be a 64-bit int, otherwise the vLLM output encoder would raise an error.
    int64 = int128 >> 64
    return int64
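
# Worked example (illustrative): uuid4().int is a 128-bit integer, and
# shifting right by 64 keeps only its top 64 bits, e.g.
#   (1 << 100) >> 64 == 1 << 36
# so the result always fits in the 64-bit field the output encoder expects.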


@jax.jit
def select_from_kv_caches(kv_caches: list[jax.Array],
                          indices: list[jax.Array]) -> list[jax.Array]:
    selected = [cache.at[indices].get() for cache in kv_caches]
    return selected


@functools.partial(
    jax.jit,
    donate_argnames=("kv_caches", ),
)
def scatter_kv_slices(kv_caches: list[jax.Array], kv_slices: list[jax.Array],
                      indices: list[jax.Array]) -> list[jax.Array]:
    num_indices = indices.shape[0]
    num_slices = kv_slices[0].shape[0]
    # indices might be padded
    assert num_slices <= num_indices

    new_kv_caches = []
    for cache, slice in zip(kv_caches, kv_slices):
        if num_slices < num_indices:
            slice = jnp.pad(slice, ((0, num_indices - num_slices), (0, 0),
                                    (0, 0), (0, 0)))
        new_cache = cache.at[indices].set(slice)
        new_kv_caches.append(new_cache)
    return new_kv_caches