tpu-inference 0.11.1 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_adapters.py +83 -0
- tests/core/test_core_tpu.py +523 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/test_lora.py +123 -0
- tests/test_base.py +201 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +218 -0
- tests/tpu_backend_test.py +59 -0
- tpu_inference/__init__.py +30 -0
- tpu_inference/adapters/__init__.py +0 -0
- tpu_inference/adapters/vllm_adapters.py +42 -0
- tpu_inference/adapters/vllm_config_adapters.py +134 -0
- tpu_inference/backend.py +69 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/adapters.py +153 -0
- tpu_inference/core/core_tpu.py +776 -0
- tpu_inference/core/disagg_executor.py +117 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/di/__init__.py +0 -0
- tpu_inference/di/abstracts.py +28 -0
- tpu_inference/di/host.py +76 -0
- tpu_inference/di/interfaces.py +51 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/tpu_connector.py +699 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +346 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/interfaces/__init__.py +0 -0
- tpu_inference/interfaces/cache.py +31 -0
- tpu_inference/interfaces/config.py +47 -0
- tpu_inference/interfaces/config_parts.py +117 -0
- tpu_inference/interfaces/engine.py +51 -0
- tpu_inference/interfaces/outputs.py +22 -0
- tpu_inference/interfaces/params.py +21 -0
- tpu_inference/interfaces/platform.py +74 -0
- tpu_inference/interfaces/request.py +39 -0
- tpu_inference/interfaces/scheduler.py +31 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +254 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/attention_interface.py +356 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/binary_search.py +295 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +172 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +95 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
- tpu_inference/layers/jax/sharding.py +406 -0
- tpu_inference/layers/jax/transformer_block.py +76 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +184 -0
- tpu_inference/layers/vllm/fused_moe.py +399 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +34 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
- tpu_inference/layers/vllm/sharding.py +151 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +308 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1233 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +433 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/llama3.py +366 -0
- tpu_inference/models/jax/llama4.py +473 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +976 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
- tpu_inference/models/jax/utils/weight_utils.py +510 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_jax.py +257 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table_jax.py +122 -0
- tpu_inference/runner/compilation_manager.py +672 -0
- tpu_inference/runner/input_batch_jax.py +435 -0
- tpu_inference/runner/kv_cache.py +119 -0
- tpu_inference/runner/kv_cache_manager.py +460 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +208 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +250 -0
- tpu_inference/runner/structured_decoding_manager.py +89 -0
- tpu_inference/runner/tpu_jax_runner.py +771 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +334 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +294 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/_temporary_vllm_compat.py +129 -0
- tpu_inference/worker/base.py +100 -0
- tpu_inference/worker/tpu_worker_jax.py +321 -0
- tpu_inference-0.11.1.dist-info/METADATA +101 -0
- tpu_inference-0.11.1.dist-info/RECORD +168 -0
- tpu_inference-0.11.1.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tpu_inference/utils.py
ADDED
@@ -0,0 +1,294 @@
# SPDX-License-Identifier: Apache-2.0
import os
from collections import defaultdict
from collections.abc import Sequence
from typing import Any, Callable, List, Tuple

import jax
import jax.numpy as jnp
import numpy as np
from jax._src import dtypes
from jax._src import mesh as mesh_lib
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from vllm import envs, utils

from tpu_inference.logger import init_logger

GBYTES = 1024 * 1024 * 1024
TPU_HEAD_SIZE_ALIGNMENT = 128
TPU_SECOND_LAST_MINOR = 8

# This is used to translate from a string name for a dtype
# to a formal jax.numpy dtype. One use case for this is
# converting the `--kv_cache_dtype` flag to a dtype.
TPU_STR_DTYPE_TO_JAX_DTYPE = {
    "bfloat16": jnp.bfloat16,
    "fp8": jnp.float8_e4m3fn,
    "fp8_e4m3": jnp.float8_e4m3,
    "fp8_e5m2": jnp.float8_e5m2,
    "int8": jnp.int8,
}

_megacore = False
logger = init_logger(__name__)


def enable_megacore() -> None:
    global _megacore
    _megacore = True


def get_megacore() -> bool:
    return _megacore


def get_num_kv_heads_by_tp(num_kv_heads: int, tp_size: int) -> int:
    if tp_size <= num_kv_heads:
        assert num_kv_heads % tp_size == 0
        return num_kv_heads
    else:
        assert tp_size % num_kv_heads == 0
        return tp_size


def hbm_usage_bytes(devices: Any) -> List[Tuple[int, int]]:
    usage = []
    if envs.VLLM_TPU_USING_PATHWAYS:
        return pathways_hbm_usage_gb(devices)

    multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
    if multihost_backend == "ray":
        # MemoryStats is only supported for addressable PjRt devices.
        # Assume all the devices have similar memory usage for now.
        # TODO(ranlihao): find a proper way to get the memory usage of each device.
        for device in devices:
            try:
                hbm_used = device.memory_stats()["bytes_in_use"]
                hbm_limit = device.memory_stats()["bytes_limit"]
                logger.info(
                    "Get memory stats for device %s. Assuming all devices have the same usage.",
                    device)
                usage.extend([(hbm_used, hbm_limit)] * len(devices))
                break
            except Exception as e:
                logger.warning(
                    "Failed to get memory stats for device %s: %s. ", device,
                    e)
    else:
        for device in devices:
            hbm_used = device.memory_stats()["bytes_in_use"]
            hbm_limit = device.memory_stats()["bytes_limit"]
            usage.append((hbm_used, hbm_limit))

    return usage


def get_device_name(num_devices: int | None = None):
    kind = jax.devices()[0].device_kind
    if 'TPU' not in kind:
        raise RuntimeError('Expected TPU devices')
    suffix = ''
    if kind.endswith(' lite'):
        kind = kind[:-len(' lite')]
        suffix = 'e'
    elif kind.endswith('e'):
        kind = kind[:-1]
        suffix = 'e'
    elif kind.endswith('p'):
        kind = kind[:-1]
        suffix = 'p'
    elif kind == 'TPU7x':
        kind = 'TPU v7'
    assert kind[:-1] == 'TPU v', kind
    kind += suffix
    if num_devices is not None:
        kind += f'-{num_devices}'
    return kind


def get_device_hbm_limit() -> int:
    device_kind = get_device_name()
    if device_kind == "TPU v5p" or device_kind == "TPU v5":
        return 95 * GBYTES
    elif device_kind == "TPU v5e":
        return 16 * GBYTES
    elif device_kind == "TPU v6e" or device_kind == "TPU v4":
        return 32 * GBYTES
    elif device_kind == "TPU v7":
        return 192 * GBYTES
    else:
        raise ValueError(f"Unknown device kind: {device_kind}")


def pathways_hbm_usage_gb(devices: Any) -> List[Tuple[float, float]]:
    live_arrays = jax.live_arrays()
    hbm_used = defaultdict(int)
    hbm_limit = get_device_hbm_limit()
    for array in live_arrays:
        assert hasattr(array, 'sharding') and hasattr(
            array.sharding, 'device_set'
        ), "This function must not be called within jax tracer (e.g. jit, vmap, grad)"
        for device in array.sharding.device_set:
            hbm_used[device] += array.dtype.itemsize * array.size // len(
                array.sharding.device_set)
    return [(hbm_used[device], hbm_limit) for device in devices]


def hbm_usage_gb(devices: Any) -> List[Tuple[float, float]]:
    usage = hbm_usage_bytes(devices)
    usage = [(round(used / GBYTES, 2), round(limit / GBYTES, 2))
             for used, limit in usage]
    return usage


def get_padded_head_dim(head_dim: int) -> int:
    """Pads head_dim up to the nearest multiple of 128 for kernel performance."""
    return (head_dim + 127) // 128 * 128


def get_padded_num_heads(num_heads: int, sharding_size: int) -> int:
    if num_heads >= sharding_size:
        assert num_heads % sharding_size == 0
    else:
        assert sharding_size % num_heads == 0
        num_heads = sharding_size
    return num_heads


def get_dtype_packing(dtype):
    bits = dtypes.bit_width(dtype)
    return 32 // bits


def make_optimized_mesh(axis_shapes: Sequence[int],
                        axis_names: Sequence[str],
                        *,
                        devices: Sequence[xc.Device] | None = None):
    if devices is None:
        devices = xb.devices()
    # Sort the devices in case they are passed in an arbitrary order.
    devices = sorted(devices, key=lambda x: x.coords)

    def _is_1D(axis_shapes):
        return sum(x > 1 for x in axis_shapes) == 1

    if _is_1D(axis_shapes):
        dev_kind = devices[0].device_kind
        device_num = len(devices)
        if dev_kind == "TPU v6 lite":
            ordered_devices = None
            # NOTE(chengjiyao):
            # The coords of v6e-8 are
            # (0,0,0)
            # (1,0,0)
            # (0,1,0)
            # (1,1,0)
            # (0,2,0)
            # (1,2,0)
            # (0,3,0)
            # (1,3,0)
            if device_num == 8:
                ordered_devices = np.array([
                    devices[0],
                    devices[1],
                    devices[2],
                    devices[3],
                    devices[7],
                    devices[6],
                    devices[5],
                    devices[4],
                ])
            # NOTE(chengjiyao):
            # The coords of v6e-4 are
            # (0,0,0)
            # (1,0,0)
            # (0,1,0)
            # (1,1,0)
            elif device_num == 4:
                ordered_devices = np.array([
                    devices[0],
                    devices[1],
                    devices[3],
                    devices[2],
                ])
            if ordered_devices is not None:
                ordered_devices = np.array(ordered_devices)
                ordered_devices = ordered_devices.reshape(axis_shapes)
                mesh = mesh_lib.Mesh(ordered_devices, axis_names)
                logger.info("Use customized mesh: %s", mesh)
                return mesh

    return jax.make_mesh(axis_shapes, axis_names, devices=devices)


def device_array(mesh: Mesh, *args, sharding=None, **kwargs) -> jax.Array:
    """
    Create a device array with the specified mesh and sharding.

    Args:
        mesh: The JAX mesh to use for device placement.
        *args: Positional arguments to pass to jax.device_put.
        sharding: Optional sharding specification. If None, uses PartitionSpec(None).
        **kwargs: Keyword arguments to pass to jax.device_put.

    Returns:
        A JAX array placed on the specified devices.
    """
    if sharding is None:
        sharding = NamedSharding(mesh, PartitionSpec(None))
    return jax.device_put(*args, device=sharding, **kwargs)


def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
    """
    A wrapper around vllm.utils.get_hash_fn_by_name that also supports the
    builtin hash function.
    """
    if hash_fn_name == "builtin":
        return hash
    return utils.get_hash_fn_by_name(hash_fn_name)


def quantize_kv(key: jax.Array, value: jax.Array,
                kv_cache_quantized_dtype: jnp.dtype, k_scale: float,
                v_scale: float) -> Tuple[jax.Array, jax.Array]:
    """
    Quantize the key and value tensors.

    Args:
        key: The key tensor to quantize.
        value: The value tensor to quantize.
        kv_cache_quantized_dtype: The dtype to quantize the key and value tensors to.
        k_scale: The scale to quantize the key tensor by.
        v_scale: The scale to quantize the value tensor by.

    Returns:
        Tuple[jax.Array, jax.Array]: The quantized key and value tensors.
    """
    dtype_info = jnp.finfo(kv_cache_quantized_dtype)
    minval, maxval = float(dtype_info.min), float(dtype_info.max)
    key = key.astype(jnp.float32) / k_scale
    key = jnp.clip(key, minval, maxval)
    key = key.astype(kv_cache_quantized_dtype)
    value = value.astype(jnp.float32) / v_scale
    value = jnp.clip(value, minval, maxval)
    value = value.astype(kv_cache_quantized_dtype)

    return key, value


def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
    """
    Get the JAX dtype from a string dtype.

    Args:
        str_dtype: The string dtype to get the JAX dtype from.

    Returns:
        jnp.dtype: The JAX dtype.
    """
    str_dtype = str_dtype.lower().strip()
    return TPU_STR_DTYPE_TO_JAX_DTYPE.get(str_dtype)
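
A quick usage sketch (not part of the package source) showing how the dtype-mapping and KV-quantization helpers above fit together; the shapes and scales are illustrative assumptions, and the import assumes the wheel and its vllm dependency are installed:

import jax.numpy as jnp

from tpu_inference.utils import (get_jax_dtype_from_str_dtype,
                                 get_padded_head_dim, quantize_kv)

# "fp8_e5m2" is one of the keys in TPU_STR_DTYPE_TO_JAX_DTYPE.
kv_dtype = get_jax_dtype_from_str_dtype("fp8_e5m2")  # -> jnp.float8_e5m2
head_dim = get_padded_head_dim(96)                   # -> 128 (padded to the 128 alignment)

key = jnp.ones((4, head_dim), dtype=jnp.bfloat16)
value = jnp.ones((4, head_dim), dtype=jnp.bfloat16)

# Per-tensor scales; values are clipped to the finite range of the target dtype.
q_key, q_value = quantize_kv(key, value, kv_dtype, k_scale=1.0, v_scale=1.0)
print(q_key.dtype, q_value.dtype)  # float8_e5m2 float8_e5m2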
tpu_inference/worker/_temporary_vllm_compat.py
ADDED
@@ -0,0 +1,129 @@
# SPDX-License-Identifier: Apache-2.0

#
# WARNING: This is a temporary compatibility module.
#
#
# THE PROBLEM:
# The ideal dependency injection pattern dictates that the "producer" of data
# (in this case, the vLLM engine) should be responsible for adapting its data
# into the abstract format that the "consumer" (the TPU worker) expects.
#
# However, this would require a simultaneous code change in both the `vllm` and
# `tpu_inference` repositories. Such cross-repository changes are difficult to
# coordinate, slow to land, and can easily cause breakages if the releases
# are not perfectly synchronized.
#
#
# THE TEMPORARY SOLUTION:
# To enable independent development and deployment, we are temporarily violating
# this pattern. We are making the consumer (`tpu_inference`) responsible for
# detecting and adapting the producer's raw data.
#
# This function checks if it has received a raw `vllm.SchedulerOutput` and,
# if so, wraps it in the appropriate adapter. This allows `vllm` to continue
# sending its raw data type without modification, decoupling the release cycles.
#
#
# THE FUTURE (HOW TO REMOVE THIS):
# This entire file should be deleted once the `vllm` repository has been updated.
# The required change in `vllm` is small and looks like this:
#
# --- SKELETON CODE FOR FUTURE vLLM CHANGE ---
# In the vLLM engine, where `execute_model` is called:
#
# from tpu_inference.adapters.vllm_adapters import VllmSchedulerOutputAdapter
# from vllm.v1.core.sched.output import SchedulerOutput
#
# # ... inside some method ...
#
# # OLD CODE:
# # concrete_work = SchedulerOutput(...)
# # self.tpu_backend.execute_model(concrete_work)
#
# # NEW CODE:
# concrete_work = SchedulerOutput(...)
# adapted_work = VllmSchedulerOutputAdapter(concrete_work)  # This line is added
# self.tpu_backend.execute_model(adapted_work)  # Pass the adapter
#
# --- END SKELETON CODE ---
#

import logging
from typing import Union

from vllm.lora.request import LoRARequest as VllmLoRARequest
# Import the concrete vLLM type for the check
from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig as VllmKVCacheConfig

from tpu_inference.adapters.vllm_adapters import (VllmKVCacheConfigAdapter,
                                                  VllmLoRARequestAdapter,
                                                  VllmSchedulerOutputAdapter)
from tpu_inference.di.abstracts import (AbstractKVCacheConfig,
                                        AbstractLoRARequest,
                                        AbstractSchedulerOutput)

logger = logging.getLogger(__name__)


def adapt_scheduler_output_if_needed(
    scheduler_output: Union[AbstractSchedulerOutput, VllmSchedulerOutput]
) -> AbstractSchedulerOutput:
    """
    Checks if the input is a raw VllmSchedulerOutput and wraps it.
    If it's already an AbstractSchedulerOutput, it's passed through.
    """
    if isinstance(scheduler_output, VllmSchedulerOutput):
        # logger.warning(
        #     "Received raw VllmSchedulerOutput. Performing temporary, on-the-fly "
        #     "adaptation. This is a compatibility feature and should be removed "
        #     "once the vLLM engine is updated to provide an adapted object.")
        return VllmSchedulerOutputAdapter(scheduler_output)

    if isinstance(scheduler_output, AbstractSchedulerOutput):
        return scheduler_output

    raise TypeError(
        f"Unsupported type for scheduler_output: {type(scheduler_output)}")


def adapt_kv_cache_config_if_needed(
    kv_cache_config: Union[AbstractKVCacheConfig, VllmKVCacheConfig]
) -> AbstractKVCacheConfig:
    """
    Checks if the input is a raw VllmKVCacheConfig and wraps it.
    If it's already an AbstractKVCacheConfig, it's passed through.
    """
    if isinstance(kv_cache_config, VllmKVCacheConfig):
        # logger.warning(
        #     "Received raw VllmKVCacheConfig. Performing temporary, on-the-fly "
        #     "adaptation. This is a compatibility feature and should be removed "
        #     "once the vLLM engine is updated to provide an adapted object.")
        return VllmKVCacheConfigAdapter(kv_cache_config)

    if isinstance(kv_cache_config, AbstractKVCacheConfig):
        return kv_cache_config

    raise TypeError(
        f"Unsupported type for kv_cache_config: {type(kv_cache_config)}")


def adapt_lora_request_if_needed(
    lora_request: Union[AbstractLoRARequest, VllmLoRARequest]
) -> AbstractLoRARequest:
    """
    Checks if the input is a raw VllmLoRARequest and wraps it.
    If it's already an AbstractLoRARequest, it's passed through.
    """
    if isinstance(lora_request, VllmLoRARequest):
        # logger.warning(
        #     "Received raw VllmLoRARequest. Performing temporary, on-the-fly "
        #     "adaptation. This is a compatibility feature and should be removed "
        #     "once the vLLM engine is updated to provide an adapted object.")
        return VllmLoRARequestAdapter(lora_request)

    if isinstance(lora_request, AbstractLoRARequest):
        return lora_request

    raise TypeError(f"Unsupported type for lora_request: {type(lora_request)}")
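
A minimal worker-side sketch (my own illustration, not shipped in the wheel) of how these helpers are meant to be called: normalize whatever the engine hands over before the worker touches it. `raw_output` and `execute_model_compat` are hypothetical names; `adapt_scheduler_output_if_needed` is the function defined above.

from tpu_inference.worker._temporary_vllm_compat import (
    adapt_scheduler_output_if_needed)


def execute_model_compat(worker, raw_output):
    # raw_output may be a raw vllm SchedulerOutput or an already-adapted
    # AbstractSchedulerOutput; either way the worker only sees the abstraction.
    scheduler_output = adapt_scheduler_output_if_needed(raw_output)
    return worker.execute_model(scheduler_output)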
tpu_inference/worker/base.py
ADDED
@@ -0,0 +1,100 @@
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from typing import Optional, Union

import torch.nn as nn
from vllm.lora.request import LoRARequest
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput

from tpu_inference.di.abstracts import (AbstractKVCacheConfig,
                                        AbstractKVCacheSpec,
                                        AbstractLoRARequest,
                                        AbstractSchedulerOutput)
from tpu_inference.di.interfaces import HostInterface


class AbstractTpuWorker(ABC):
    """Base class for TPU workers.

    This class defines a pure, host-agnostic contract for what a TPU worker
    must be able to do. It is intentionally decoupled from any specific host
    system like vLLM or SGLang.

    Architectural Note on Dependencies:
    This abstract class only depends on other abstractions (e.g., HostInterface).
    It does NOT hold configuration objects from any specific host (e.g.,
    VllmConfig). Doing so would create a "leaky abstraction," forcing all
    future implementations to depend on a concrete detail from a single host.

    The responsibility for managing concrete configuration is pushed down to the
    concrete subclasses (e.g., TPUWorkerJax), which keeps this base class
    pure and truly reusable across different host systems.
    """

    def __init__(self, host_interface: Optional[HostInterface] = None):
        self.host_interface = host_interface

    @abstractmethod
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the cache with the given number of blocks."""
        pass

    @abstractmethod
    def init_device(self):
        """Initialize the TPU device and distributed environment."""
        pass

    @abstractmethod
    def determine_available_memory(self) -> int:
        """Determine available memory for the TPU worker."""
        pass

    @abstractmethod
    def execute_model(
        self,
        scheduler_output: Union[AbstractSchedulerOutput, SchedulerOutput],
    ) -> Optional[ModelRunnerOutput]:
        pass

    @abstractmethod
    def profile(self, is_start: bool = True):
        pass

    @abstractmethod
    def add_lora(
        self,
        lora_request: Union[AbstractLoRARequest, LoRARequest],
    ) -> bool:
        pass

    @abstractmethod
    def load_model(self) -> None:
        pass

    @abstractmethod
    def compile_or_warm_up_model(self) -> None:
        pass

    @abstractmethod
    def get_model(self) -> nn.Module:
        pass

    @abstractmethod
    def get_kv_cache_spec(self) -> dict[str, AbstractKVCacheSpec]:
        pass

    @abstractmethod
    def initialize_from_config(
        self,
        kv_cache_config: Union[AbstractKVCacheConfig, KVCacheConfig],
    ) -> None:
        """Allocate KV cache with the specified kv_cache_config."""
        pass

    def check_health(self) -> None:
        # worker will always be healthy as long as it's running.
        return
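
To illustrate the contract, here is a hypothetical do-nothing subclass (my own sketch, not included in the package; the real implementation lives in tpu_inference/worker/tpu_worker_jax.py). It shows the full surface a host-specific worker must provide before it can be instantiated; the import path is assumed from the file layout above.

import torch.nn as nn

from tpu_inference.worker.base import AbstractTpuWorker


class NoOpTpuWorker(AbstractTpuWorker):
    """Implements every abstract method with a do-nothing placeholder."""

    def init_device(self):
        pass

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        pass

    def determine_available_memory(self) -> int:
        return 0

    def execute_model(self, scheduler_output):
        return None

    def profile(self, is_start: bool = True):
        pass

    def add_lora(self, lora_request) -> bool:
        return False

    def load_model(self) -> None:
        pass

    def compile_or_warm_up_model(self) -> None:
        pass

    def get_model(self) -> nn.Module:
        return nn.Identity()

    def get_kv_cache_spec(self):
        return {}

    def initialize_from_config(self, kv_cache_config) -> None:
        pass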