tpu_inference-0.11.1.dev202511150811-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/utils.py
ADDED
# SPDX-License-Identifier: Apache-2.0
import os
import time
from collections import defaultdict
from collections.abc import Sequence
from functools import wraps
from typing import Any, Callable, List, Tuple

import jax
import jax.numpy as jnp
import numpy as np
from jax._src import dtypes
from jax._src import mesh as mesh_lib
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from vllm import envs, utils

from tpu_inference.logger import init_logger

GBYTES = 1024 * 1024 * 1024
TPU_HEAD_SIZE_ALIGNMENT = 128
TPU_SECOND_LAST_MINOR = 8

# This is used to translate from a string name for a dtype
# to a formal jax.numpy dtype. One use case for this is
# converting the `--kv_cache_dtype` flag to a dtype.
TPU_STR_DTYPE_TO_JAX_DTYPE = {
    "bfloat16": jnp.bfloat16,
    "fp8": jnp.float8_e4m3fn,
    "fp8_e4m3": jnp.float8_e4m3,
    "fp8_e5m2": jnp.float8_e5m2,
    "int8": jnp.int8,
}

_megacore = False
logger = init_logger(__name__)


def enable_megacore() -> None:
    global _megacore
    _megacore = True


def get_megacore() -> bool:
    return _megacore


def get_num_kv_heads_by_tp(num_kv_heads: int, tp_size: int) -> int:
    if tp_size <= num_kv_heads:
        assert num_kv_heads % tp_size == 0
        return num_kv_heads
    else:
        assert tp_size % num_kv_heads == 0
        return tp_size


def hbm_usage_bytes(devices: Any) -> List[Tuple[int, int]]:
    usage = []
    if envs.VLLM_TPU_USING_PATHWAYS:
        return pathways_hbm_usage_gb(devices)

    multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
    if multihost_backend == "ray":
        # MemoryStats is only supported for addressable PjRt devices.
        # Assume all the devices have similar memory usage for now.
        # TODO(ranlihao): find a proper way to get the memory usage of each device.
        for device in devices:
            try:
                hbm_used = device.memory_stats()["bytes_in_use"]
                hbm_limit = device.memory_stats()["bytes_limit"]
                logger.info(
                    "Get memory stats for device %s. Assuming all devices have the same usage.",
                    device)
                usage.extend([(hbm_used, hbm_limit)] * len(devices))
                break
            except Exception as e:
                logger.warning(
                    "Failed to get memory stats for device %s: %s. ", device,
                    e)
    else:
        for device in devices:
            hbm_used = device.memory_stats()["bytes_in_use"]
            hbm_limit = device.memory_stats()["bytes_limit"]
            usage.append((hbm_used, hbm_limit))

    return usage


def get_device_name(num_devices: int | None = None):
    kind = jax.devices()[0].device_kind
    if 'TPU' not in kind:
        raise RuntimeError('Expected TPU devices')
    suffix = ''
    if kind.endswith(' lite'):
        kind = kind[:-len(' lite')]
        suffix = 'e'
    elif kind.endswith('e'):
        kind = kind[:-1]
        suffix = 'e'
    elif kind.endswith('p'):
        kind = kind[:-1]
        suffix = 'p'
    elif kind == 'TPU7x':
        kind = 'TPU v7'
    assert kind[:-1] == 'TPU v', kind
    kind += suffix
    if num_devices is not None:
        kind += f'-{num_devices}'
    return kind


def get_device_hbm_limit() -> int:
    device_kind = get_device_name()
    if device_kind == "TPU v5p" or device_kind == "TPU v5":
        return 95 * GBYTES
    elif device_kind == "TPU v5e":
        return 16 * GBYTES
    elif device_kind == "TPU v6e" or device_kind == "TPU v4":
        return 32 * GBYTES
    elif device_kind == "TPU v7":
        # 192 * GBYTES / 2 because each JAX device (v7x core) has
        # 1/2 of the total chip HBM
        return 96 * GBYTES
    else:
        raise ValueError(f"Unknown device kind: {device_kind}")


def pathways_hbm_usage_gb(devices: Any) -> List[Tuple[float, float]]:
    live_arrays = jax.live_arrays()
    hbm_used = defaultdict(int)
    hbm_limit = get_device_hbm_limit()
    for array in live_arrays:
        for buffer in array.addressable_shards:
            hbm_used[buffer.data.device] += buffer.data.nbytes
    return [(hbm_used[device], hbm_limit) for device in devices]


def hbm_usage_gb(devices: Any) -> List[Tuple[float, float]]:
    usage = hbm_usage_bytes(devices)
    usage = [(round(used / GBYTES, 2), round(limit / GBYTES, 2))
             for used, limit in usage]
    return usage


def get_padded_head_dim(head_dim: int) -> int:
    """Pads head_dim up to the nearest multiple of 128 for kernel performance."""
    # When head_dim == 64, we use a kernel specifically optimized for it which
    # does not require any padding.
    if head_dim == 64:
        return 64
    return (head_dim + 127) // 128 * 128


def get_padded_num_heads(num_heads: int, sharding_size: int) -> int:
    if num_heads >= sharding_size:
        assert num_heads % sharding_size == 0
    else:
        assert sharding_size % num_heads == 0
        num_heads = sharding_size
    return num_heads


def get_dtype_packing(dtype):
    bits = dtypes.bit_width(dtype)
    return 32 // bits


def make_optimized_mesh(axis_shapes: Sequence[int],
                        axis_names: Sequence[str],
                        *,
                        devices: Sequence[xc.Device] | None = None):
    if devices is None:
        devices = xb.devices()
    # Sort the devices in case they are passed in an arbitrary order.
    devices = sorted(devices, key=lambda x: x.coords)

    def _is_1D(axis_shapes):
        return sum(x > 1 for x in axis_shapes) == 1

    if _is_1D(axis_shapes):
        dev_kind = devices[0].device_kind
        device_num = len(devices)
        if dev_kind == "TPU v6 lite":
            ordered_devices = None
            # NOTE(chengjiyao):
            # The coords of v6e-8 are
            # (0,0,0)
            # (1,0,0)
            # (0,1,0)
            # (1,1,0)
            # (0,2,0)
            # (1,2,0)
            # (0,3,0)
            # (1,3,0)
            if device_num == 8:
                ordered_devices = np.array([
                    devices[0],
                    devices[1],
                    devices[2],
                    devices[3],
                    devices[7],
                    devices[6],
                    devices[5],
                    devices[4],
                ])
            # NOTE(chengjiyao):
            # The coords of v6e-4 are
            # (0,0,0)
            # (1,0,0)
            # (0,1,0)
            # (1,1,0)
            elif device_num == 4:
                ordered_devices = np.array([
                    devices[0],
                    devices[1],
                    devices[3],
                    devices[2],
                ])
            if ordered_devices is not None:
                ordered_devices = np.array(ordered_devices)
                ordered_devices = ordered_devices.reshape(axis_shapes)
                mesh = mesh_lib.Mesh(ordered_devices, axis_names)
                logger.info("Use customized mesh: %s", mesh)
                return mesh

    return jax.make_mesh(axis_shapes, axis_names, devices=devices)


def device_array(mesh: Mesh, *args, sharding=None, **kwargs) -> jax.Array:
    """
    Create a device array with the specified mesh and sharding.

    Args:
        mesh: The JAX mesh to use for device placement
        *args: Positional arguments to pass to jax.device_put
        sharding: Optional sharding specification. If None, uses PartitionSpec(None)
        **kwargs: Keyword arguments to pass to jax.device_put

    Returns:
        A JAX array placed on the specified devices
    """
    if sharding is None:
        sharding = NamedSharding(mesh, PartitionSpec(None))
    return jax.device_put(*args, device=sharding, **kwargs)


def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
    """
    A wrapper around vllm.utils.get_hash_fn_by_name that also supports the
    builtin hash function.
    """
    if hash_fn_name == "builtin":
        return hash
    return utils.get_hash_fn_by_name(hash_fn_name)


def quantize_kv(key: jax.Array, value: jax.Array,
                kv_cache_quantized_dtype: jnp.dtype, k_scale: float,
                v_scale: float) -> Tuple[jax.Array, jax.Array]:
    """
    Quantize the key and value tensors.

    Args:
        key: The key tensor to quantize.
        value: The value tensor to quantize.
        kv_cache_quantized_dtype: The dtype to quantize the key and value tensors to.
        k_scale: The scale to quantize the key tensor by.
        v_scale: The scale to quantize the value tensor by.

    Returns:
        Tuple[jax.Array, jax.Array]: The quantized key and value tensors.
    """
    dtype_info = jnp.finfo(kv_cache_quantized_dtype)
    minval, maxval = float(dtype_info.min), float(dtype_info.max)
    key = key.astype(jnp.float32) / k_scale
    key = jnp.clip(key, minval, maxval)
    key = key.astype(kv_cache_quantized_dtype)
    value = value.astype(jnp.float32) / v_scale
    value = jnp.clip(value, minval, maxval)
    value = value.astype(kv_cache_quantized_dtype)

    return key, value


def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
    """
    Get the JAX dtype from a string dtype.

    Args:
        str_dtype: The string dtype to get the JAX dtype from.

    Returns:
        jnp.dtype: The JAX dtype.
    """
    str_dtype = str_dtype.lower().strip()
    return TPU_STR_DTYPE_TO_JAX_DTYPE.get(str_dtype)


def time_function(func):
    """
    A decorator to measure the execution time of a function.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        end_time = time.perf_counter()
        execution_time = end_time - start_time
        logger.debug(
            f"Function '{func.__name__}' executed in {execution_time:.4f} seconds."
        )
        return result

    return wrapper
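For reference, a minimal usage sketch of a few of the helpers above. Only the function names come from the file itself; the shapes, scale values, and dtype string are illustrative placeholders and assume a host where JAX and this package import cleanly.

    import jax.numpy as jnp
    from tpu_inference import utils

    # Head dimensions are padded to multiples of 128, with 64 special-cased.
    assert utils.get_padded_head_dim(64) == 64
    assert utils.get_padded_head_dim(80) == 128

    # Resolve a `--kv_cache_dtype` string and quantize a dummy KV pair.
    kv_dtype = utils.get_jax_dtype_from_str_dtype("fp8_e5m2")
    key = jnp.ones((1, 8, 128), dtype=jnp.bfloat16)
    value = jnp.ones((1, 8, 128), dtype=jnp.bfloat16)
    q_key, q_value = utils.quantize_kv(key, value, kv_dtype,
                                       k_scale=0.1, v_scale=0.1)

    # Log the runtime of any function at debug level.
    @utils.time_function
    def warmup_step():
        return sum(range(1_000))

    warmup_step()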
tpu_inference/worker/__init__.py
File without changes

tpu_inference/worker/tpu_worker.py
ADDED
# SPDX-License-Identifier: Apache-2.0

import os
import tempfile
from typing import Callable, Dict, Optional, Tuple

import jax
import jax.numpy as jnp
import jaxlib
import jaxtyping
import vllm.envs as vllm_envs
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
                                          has_kv_transfer_group)
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                             init_distributed_environment)
from vllm.lora.request import LoRARequest
from vllm.tasks import SupportedTask
from vllm.v1 import utils as vllm_utils
from vllm.v1.core.kv_cache_utils import get_num_blocks, get_uniform_page_size
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput

from tpu_inference import envs, utils
from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
                                             get_node_id)
from tpu_inference.layers.common.sharding import ShardingConfigManager
from tpu_inference.logger import init_logger
from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
from tpu_inference.runner.tpu_runner import TPUModelRunner

logger = init_logger(__name__)

_DTYPE: dict[str, jnp.dtype] = {
    "bfloat16": jnp.bfloat16,
    "float": jnp.float32,
    "float32": jnp.float32,
}


class TPUWorker:

    def __init__(self,
                 vllm_config: VllmConfig,
                 local_rank: int,
                 rank: int,
                 distributed_init_method: str,
                 is_driver_worker: bool = False,
                 devices=None):
        # If we use vLLM's model implementation in PyTorch, we should set it
        # with torch version of the dtype.
        impl = envs.MODEL_IMPL_TYPE
        if impl != "vllm":  # vllm-pytorch implementation does not need this conversion

            # NOTE(wenlong): because sometimes mm needs to use torch for preprocessing
            if not isinstance(vllm_config.model_config.dtype, str):
                logger.warning(
                    "The model dtype is not properly set for JAX backend. "
                    "Overwriting it to jnp.bfloat16")
                vllm_config.model_config.dtype = jnp.bfloat16
            else:
                vllm_config.model_config.dtype = _DTYPE.get(
                    vllm_config.model_config.dtype, jnp.bfloat16)

        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.cache_config = vllm_config.cache_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker
        self.devices = devices if devices is not None else []
        self.device_ranks = set(device.id for device in self.devices
                                if isinstance(device, jaxlib._jax.Device))

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules

            init_cached_hf_modules()

        # Delay profiler initialization to the start of the profiling.
        # This is because in vLLM V1, MP runtime is initialized before the
        # TPU Worker is initialized. The profiler server needs to start after
        # MP runtime is initialized.
        self.profile_dir = None
        if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
            if not self.devices or 0 in self.device_ranks:
                # For TPU, we can only have 1 active profiler session for 1 profiler
                # server. So we only profile on rank0.
                self.profile_dir = vllm_envs.VLLM_TORCH_PROFILER_DIR
                logger.info("Profiling enabled. Traces will be saved to: %s",
                            self.profile_dir)

        use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False)
        # Only one instance of profiler is allowed
        if use_jax_profiler_server and self.rank < 1:
            if not self.devices or 0 in self.device_ranks:
                jax_profiler_server_port = int(
                    os.getenv("JAX_PROFILER_SERVER_PORT", 9999))
                logger.info(
                    f"Starting JAX profiler server on port {jax_profiler_server_port}"
                )
                jax.profiler.start_server(jax_profiler_server_port)

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

    def init_device(self):
        if not self.devices:
            sharding_config: ShardingConfigManager = self.vllm_config.sharding_config
            device_indexes = sharding_config.device_indexes
            if device_indexes is not None and len(device_indexes) > 0:
                # Enforcing the devices sequence to be consistent with the specified device indexes
                all_devices = jax.devices()
                device_dict = {device.id: device for device in all_devices}
                self.devices = []
                for device_index in device_indexes:
                    device = device_dict[device_index]
                    if device is None:
                        raise KeyError(
                            f"Device index {device_index} not found in "
                            f"jax.devices() with IDs {list(device_dict.keys())}!"
                        )
                    self.devices.append(device)
                self.devices = self.devices[:sharding_config.total_devices]
            else:
                self.devices = jax.devices()[:sharding_config.total_devices]

        # Initialize the vLLM distribution layer as a single chip environment,
        # we'll swap the model's parallel modules with TPU SPMD equivalents.
        with set_current_vllm_config(self.vllm_config):
            temp_file = tempfile.mkstemp()[1]
            init_distributed_environment(
                world_size=1,
                rank=0,
                local_rank=0,
                distributed_init_method=f"file://{temp_file}",
                backend="gloo",
            )
            ensure_model_parallel_initialized(
                tensor_model_parallel_size=1,
                pipeline_model_parallel_size=1,
            )
            ensure_kv_transfer_initialized(self.vllm_config)
        self.model_runner = TPUModelRunner(self.vllm_config, self.devices)
        logger.info(f"Init worker | "
                    f"rank={self.rank} | "
                    f"node_id={get_node_id()} | "
                    f"is_driver_worker={self.is_driver_worker} | "
                    f"hbm={utils.hbm_usage_gb(self.devices)}GiB")
        vllm_utils.report_usage_stats(self.vllm_config)

    def determine_available_memory(self) -> int:
        gpu_memory_utilization = self.cache_config.gpu_memory_utilization
        hbm_usage = utils.hbm_usage_bytes(self.devices)
        total_hbm_limit = total_hbm_used = 0
        for used, limit in hbm_usage:
            total_hbm_used += used
            total_hbm_limit += limit

        total_hbm_limit_cap = total_hbm_limit * gpu_memory_utilization
        total_hbm_avail = int(total_hbm_limit_cap - total_hbm_used)

        total_hbm_limit_gb = round(total_hbm_limit / utils.GBYTES, 2)
        total_hbm_limit_cap_gb = round(total_hbm_limit_cap / utils.GBYTES, 2)
        total_hbm_used_gb = round(total_hbm_used / utils.GBYTES, 2)
        total_hbm_avail_gb = round(total_hbm_avail / utils.GBYTES, 2)

        logger.info(f"Memory statistics | "
                    f"{total_hbm_limit_gb=}GiB | "
                    f"{total_hbm_limit_cap_gb=}GiB | "
                    f"{total_hbm_used_gb=}GiB | "
                    f"{total_hbm_avail_gb=}GiB")

        if total_hbm_avail <= 0:
            raise ValueError(f"{total_hbm_used_gb=}GiB exceeds "
                             f"{total_hbm_limit_cap_gb=}GiB by "
                             f"{-total_hbm_avail_gb}GiB. Please consider "
                             f"increasing --gpu-memory-utilization from "
                             f"{gpu_memory_utilization} to a larger value.")
        return total_hbm_avail

    def execute_model(
        self,
        scheduler_output: SchedulerOutput,
    ) -> Optional[ModelRunnerOutput]:
        # NOTE: This method intentionally returns a concrete vLLM type, which
        # violates the pure abstract contract of the base class. This is a
        # deliberate, temporary compromise for the same reasons outlined in
        # the `get_kv_cache_spec` method.

        output = self.model_runner.execute_model(scheduler_output)

        # With a connector, the scheduler expects output from all workers
        # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
        if has_kv_transfer_group():
            return output

        return output if self.is_driver_worker else None

    def sample_tokens(self,
                      grammar_output: GrammarOutput) -> ModelRunnerOutput:
        return self.model_runner.sample_tokens(grammar_output)

    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
        return self.model_runner.take_draft_token_ids()

    def add_lora(
        self,
        lora_request: LoRARequest,
    ) -> bool:
        raise NotImplementedError(
            "LoRA is not supported by the JAX worker yet.")

    def profile(self, is_start: bool = True):
        if is_start:
            options = jax.profiler.ProfileOptions()
            # default: https://docs.jax.dev/en/latest/profiling.html#general-options
            options.python_tracer_level = os.getenv("PYTHON_TRACER_LEVEL", 0)
            options.host_tracer_level = os.getenv("HOST_TRACER_LEVEL", 1)
            jax.profiler.start_trace(self.profile_dir,
                                     profiler_options=options)
        else:
            jax.profiler.stop_trace()

    def load_model(self) -> None:
        self.model_runner.load_model()

    def compile_or_warm_up_model(self) -> None:
        self.model_runner.capture_model()
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        self.model_runner._init_random()

    def reset_mm_cache(self) -> None:
        pass

    def get_model(self):
        return self.model_runner.get_model()

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        return self.model_runner.get_supported_tasks()

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        # NOTE: This method intentionally returns a concrete vLLM type, which
        # violates the pure abstract contract of the base class. This is a
        # deliberate, temporary compromise.
        #
        # The vLLM executor that calls this method expects the concrete
        # `vllm.KVCacheSpec` object to perform its own internal logic. If we
        # returned an abstract adapter, the vLLM code would break.
        #
        # The ideal long-term solution is for the vLLM DI container to be
        # responsible for this translation. When vLLM can be modified, this
        # method should be changed to return `dict[str, AbstractKVCacheSpec]`,
        # and the vLLM side should be updated to handle the translation.
        kv_cache_specs = self.model_runner.get_kv_cache_spec()

        if len(kv_cache_specs) == 0:
            return kv_cache_specs

        # TODO(kyuyeunk): Instead of checking page_size_bytes here, introduce
        # feature that allows overriding page_size_bytes of KVCacheSpec.
        vllm_page_size_bytes = get_uniform_page_size(kv_cache_specs)
        rpa_page_size_bytes = get_rpa_page_size_bytes(self.model_runner.mesh,
                                                      kv_cache_specs)

        if vllm_page_size_bytes != rpa_page_size_bytes:
            logger.info(
                f"KV cache page size calculated by vLLM "
                f"({vllm_page_size_bytes} Bytes) does not match with actual "
                f"page size used by RPA kernel ({rpa_page_size_bytes} Bytes). "
                f"Recalculating number of KV blocks using actual page size.")

            available_memory = self.determine_available_memory()
            num_blocks = get_num_blocks(self.vllm_config, len(kv_cache_specs),
                                        available_memory, rpa_page_size_bytes)

            cache_config = self.vllm_config.cache_config
            cache_config.num_gpu_blocks_override = num_blocks

        return kv_cache_specs

    def initialize_from_config(
        self,
        kv_cache_config: KVCacheConfig,
    ) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config."""
        self.model_runner.initialize_kv_cache(kv_cache_config)

    def get_node_kv_ip_port(self) -> tuple[int, str, int]:
        node_id = get_node_id()
        ip = get_host_ip()
        port = get_kv_transfer_port()
        return (int(node_id), ip, int(port))

    def check_health(self) -> None:
        # worker will always be healthy as long as it's running.
        return

    def sync_weights(
        self,
        updated_weights: jaxtyping.PyTree,
        mappings: Dict[str, Tuple[str, Tuple[str]]],
        transpose_keys: Dict[str, Tuple[int]],
        reshard_fn: Callable[[jaxtyping.PyTree, jaxtyping.PyTree],
                             jaxtyping.PyTree] = None
    ) -> None:
        """Sync the updated weights to the model runner."""
        return self.model_runner._sync_weights(updated_weights=updated_weights,
                                               mappings=mappings,
                                               transpose_keys=transpose_keys,
                                               reshard_fn=reshard_fn)

    def shutdown(self) -> None:
        return
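For reference, a minimal sketch of the HBM budgeting arithmetic that determine_available_memory performs above: cap the reported limit by gpu_memory_utilization, then subtract what is already in use. The device numbers here are hypothetical placeholders, not values taken from the package.

    GBYTES = 1024 * 1024 * 1024

    hbm_limit_bytes = 32 * GBYTES    # hypothetical per-device bytes_limit
    hbm_used_bytes = 6 * GBYTES      # hypothetical weights + runtime buffers already resident
    gpu_memory_utilization = 0.90    # vLLM's --gpu-memory-utilization flag

    budget = hbm_limit_bytes * gpu_memory_utilization
    available_for_kv_cache = int(budget - hbm_used_bytes)
    print(round(available_for_kv_cache / GBYTES, 2))  # ~22.8 GiB left for the KV cache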