tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference has been flagged by the registry; see the package's registry page for details.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/kernels/mla_v1_test.py +129 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
- tests/lora/test_layers.py +4 -7
- tests/lora/test_lora_perf.py +53 -0
- tests/lora/utils.py +0 -8
- tests/test_envs.py +110 -12
- tests/test_quantization.py +3 -0
- tests/test_utils.py +1 -2
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +3 -4
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +93 -9
- tpu_inference/executors/ray_distributed_executor.py +9 -2
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
- tpu_inference/kernels/mla/v1/kernel.py +98 -120
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +11 -7
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +170 -208
- tpu_inference/layers/vllm/linear_common.py +43 -21
- tpu_inference/layers/vllm/quantization/common.py +11 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
- tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
- tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +84 -28
- tpu_inference/models/jax/deepseek_v3.py +185 -64
- tpu_inference/models/jax/gpt_oss.py +3 -3
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
- tpu_inference/models/jax/utils/weight_utils.py +205 -144
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
- tpu_inference/platforms/tpu_platform.py +34 -50
- tpu_inference/runner/compilation_manager.py +144 -60
- tpu_inference/runner/kv_cache.py +40 -20
- tpu_inference/runner/kv_cache_manager.py +48 -33
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +280 -149
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -21
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +46 -18
- tpu_inference/worker/tpu_worker.py +197 -63
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tpu_inference/worker/tpu_worker.py

@@ -2,14 +2,15 @@
 
 import os
 import tempfile
+from dataclasses import dataclass, field
 from typing import Callable, Dict, Optional, Tuple
 
 import jax
-import jax.numpy as jnp
 import jaxlib
 import jaxtyping
 import vllm.envs as vllm_envs
 from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.distributed import get_pp_group
 from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
                                           has_kv_transfer_group)
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
@@ -17,52 +18,59 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
 from vllm.lora.request import LoRARequest
 from vllm.tasks import SupportedTask
 from vllm.v1 import utils as vllm_utils
-from vllm.v1.core.kv_cache_utils import get_num_blocks,
+from vllm.v1.core.kv_cache_utils import (get_kv_cache_groups, get_num_blocks,
+                                         get_uniform_page_size)
 from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 
 from tpu_inference import envs, utils
+from tpu_inference.distributed import jax_parallel_state
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
                                              get_node_id)
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
-from tpu_inference.
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
+from tpu_inference.runner.kv_cache import get_attention_page_size_bytes
 from tpu_inference.runner.tpu_runner import TPUModelRunner
 
 logger = init_logger(__name__)
 
-_DTYPE: dict[str, jnp.dtype] = {
-    "bfloat16": jnp.bfloat16,
-    "float": jnp.float32,
-    "float32": jnp.float32,
-}
 
+@dataclass
+class PPConfig:
+    rank: int
+    ip: str
+    prev_worker_ip: str
+    pp_world_size: int
 
-
+    # default env vars for
+    # TPU_PROCESS_BOUNDS, TPU_CHIPS_PER_PROCESS_BOUNDS, TPU_VISIBLE_CHIPS
+    # if PP is used in single host.
+    default_tpu_process_bounds: str = field(init=False)
+    default_tpu_chips_per_process_bounds: str = field(init=False)
+    default_tpu_visible_chips: str = field(init=False)
+
+    def __post_init__(self):
+        self.default_tpu_process_bounds = f"1,{self.pp_world_size},1"
+        self.default_tpu_chips_per_process_bounds = "1,1,1"
+        self.default_tpu_visible_chips = f"{self.rank}"
 
-    def __init__(self,
-                 vllm_config: VllmConfig,
-                 local_rank: int,
-                 rank: int,
-                 distributed_init_method: str,
-                 is_driver_worker: bool = False,
-                 devices=None):
-        # If we use vLLM's model implementation in PyTorch, we should set it
-        # with torch version of the dtype.
-        impl = envs.MODEL_IMPL_TYPE
-        if impl != "vllm":  # vllm-pytorch implementation does not need this conversion
-
-            # NOTE(wenlong): because sometimes mm needs to use torch for preprocessing
-            if not isinstance(vllm_config.model_config.dtype, str):
-                logger.warning(
-                    "The model dtype is not properly set for JAX backend. "
-                    "Overwriting it to jnp.bfloat16")
-                vllm_config.model_config.dtype = jnp.bfloat16
-            else:
-                vllm_config.model_config.dtype = _DTYPE.get(
-                    vllm_config.model_config.dtype, jnp.bfloat16)
 
+class TPUWorker:
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        is_driver_worker: bool = False,
+        devices=None,
+        ip: str = "localhost",
+        prev_worker_ip: str = "localhost",
+    ):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.parallel_config = vllm_config.parallel_config
@@ -74,10 +82,12 @@ class TPUWorker:
         self.devices = devices if devices is not None else []
         self.device_ranks = set(device.id for device in self.devices
                                 if isinstance(device, jaxlib._jax.Device))
+        self.pp_config = PPConfig(rank, ip, prev_worker_ip,
+                                  self.parallel_config.pipeline_parallel_size)
 
         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
-            from vllm.utils import init_cached_hf_modules
+            from vllm.utils.import_utils import init_cached_hf_modules
 
             init_cached_hf_modules()
 
@@ -86,7 +96,7 @@ class TPUWorker:
         # TPU Worker is initialized. The profiler server needs to start after
         # MP runtime is initialized.
         self.profile_dir = None
-        if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
+        if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1 and self.pp_config.pp_world_size == 1:
             if not self.devices or 0 in self.device_ranks:
                 # For TPU, we can only have 1 active profiler session for 1 profiler
                 # server. So we only profile on rank0.
@@ -94,6 +104,14 @@
                 logger.info("Profiling enabled. Traces will be saved to: %s",
                             self.profile_dir)
 
+        # For PP, we use MPMD so we want to profile every worker.
+        if self.pp_config.pp_world_size > 1 and vllm_envs.VLLM_TORCH_PROFILER_DIR:
+            self.profile_dir = os.path.join(
+                vllm_envs.VLLM_TORCH_PROFILER_DIR,
+                f"pprank_{self.rank}_ppworldsize_{self.pp_config.pp_world_size}"
+            )
+            os.makedirs(self.profile_dir, exist_ok=True)
+
         use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False)
         # Only one instance of profiler is allowed
         if use_jax_profiler_server and self.rank < 1:
@@ -105,31 +123,87 @@
             )
             jax.profiler.start_server(jax_profiler_server_port)
 
+        # step_counter is used to calculate uuid to transfer intermediate tensors.
+        self.step_counter = 0
+
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks
 
-    def init_device(self
+    def init_device(self,
+                    tpu_process_bounds="",
+                    tpu_chips_per_process_bounds="",
+                    tpu_visible_chips=""):
+        # set tpu visible devices for Jax runtime in single host PP.
+        multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+        if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1:
+            tpu_ports = [
+                jax_parallel_state.BASE_JAX_PORT + i
+                for i in range(self.pp_config.pp_world_size)
+            ]
+            os.environ["TPU_PROCESS_ADDRESSES"] = ",".join(
+                [f"localhost:{port}" for port in tpu_ports])
+            os.environ["TPU_PROCESS_PORT"] = f"{tpu_ports[self.rank]}"
+            os.environ["CLOUD_TPU_TASK_ID"] = f"{self.rank}"
+
+            # Note: Below is the setting for v6e8 host (8 chips of v6e)
+            # Replace with your own topology.
+            # There are 2 ways of subslicing a v6e
+            # 1) 2 slices with 4 TPU chips each, we can do PP=2, TP=1/2/3/4
+            #    TPU_PROCESS_BOUNDS = "1,1,1"
+            #    TPU_CHIPS_PER_PROCESS_BOUNDS = "1,4,1"
+            #    TPU_VISIBLE_CHIPS = "0,1,2,3" or "4,5,6,7"
+            # 2) 1 chip for each subslice, with at most 8 subslices,
+            #    we can do TP=1, PP=1/2/3/4/5/6/7/8
+            os.environ[
+                "TPU_PROCESS_BOUNDS"] = tpu_process_bounds \
+                if tpu_process_bounds \
+                else self.pp_config.default_tpu_process_bounds
+            os.environ[
+                "TPU_CHIPS_PER_PROCESS_BOUNDS"] = tpu_chips_per_process_bounds \
+                if tpu_chips_per_process_bounds \
+                else self.pp_config.default_tpu_chips_per_process_bounds
+            os.environ[
+                "TPU_VISIBLE_CHIPS"] = tpu_visible_chips \
+                if tpu_visible_chips \
+                else self.pp_config.default_tpu_visible_chips
+
         if not self.devices:
             sharding_config: ShardingConfigManager = self.vllm_config.sharding_config
             device_indexes = sharding_config.device_indexes
             if device_indexes is not None and len(device_indexes) > 0:
                 # Enforcing the devices sequence to be consistent with the specified device indexes
-
-                device_dict = {
+                all_local_devices = jax.local_devices()
+                device_dict = {
+                    device.id: device
+                    for device in all_local_devices
+                }
                 self.devices = []
                 for device_index in device_indexes:
                     device = device_dict[device_index]
                     if device is None:
                         raise KeyError(
                             f"Device index {device_index} not found in "
-                            f"jax.
+                            f"jax.local_devices() with IDs {list(device_dict.keys())}!"
                         )
                     self.devices.append(device)
+                assert len(self.devices) >= sharding_config.total_devices
                 self.devices = self.devices[:sharding_config.total_devices]
             else:
-                self.
+                if self.pp_config.pp_world_size > 1:
+                    # We only support a mixed tp + pp scenario that tp size is
+                    # smaller or equals the total TPUs in one node
+                    # say: we have 4 nodes with 4 TPUs each, we can only do pp:4, tp:4, but not pp:2, tp:8
+                    assert jax.local_device_count(
+                    ) >= sharding_config.total_devices
+                    self.devices = jax.local_devices()[:sharding_config.
+                                                       total_devices]
+                else:
+                    # In a multi-host distributed env, say: Ray, local_device count may smaller
+                    # than the total devices, we just choose the smaller set here.
+                    self.devices = jax.devices()[:sharding_config.
+                                                 total_devices]
 
         # Initialize the vLLM distribution layer as a single chip environment,
         # we'll swap the model's parallel modules with TPU SPMD equivalents.
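For orientation before the remaining hunks: the new `PPConfig` dataclass derives single-host TPU topology defaults from the pipeline-parallel rank and world size, and the reworked `init_device` turns them into `TPU_*` environment variables when pipeline parallelism runs on one host. Below is a minimal standalone sketch of that derivation; the numeric base port is a placeholder, since the value of `jax_parallel_state.BASE_JAX_PORT` is not shown in this diff, and the dictionary only illustrates what `init_device` writes into `os.environ`.

```python
from dataclasses import dataclass, field

# Placeholder: the real base port is jax_parallel_state.BASE_JAX_PORT, whose
# value is not visible in this diff.
BASE_JAX_PORT = 8700


@dataclass
class PPConfig:
    """Standalone mirror of the PPConfig added in this diff."""
    rank: int
    ip: str
    prev_worker_ip: str
    pp_world_size: int
    default_tpu_process_bounds: str = field(init=False)
    default_tpu_chips_per_process_bounds: str = field(init=False)
    default_tpu_visible_chips: str = field(init=False)

    def __post_init__(self):
        # One TPU process per pipeline stage on a single host ...
        self.default_tpu_process_bounds = f"1,{self.pp_world_size},1"
        # ... each owning one chip, namely the chip whose index matches its rank.
        self.default_tpu_chips_per_process_bounds = "1,1,1"
        self.default_tpu_visible_chips = f"{self.rank}"


# Stage 1 of a 2-stage, single-host pipeline.
cfg = PPConfig(rank=1, ip="localhost", prev_worker_ip="localhost",
               pp_world_size=2)

# Roughly what init_device() would write into os.environ for this rank.
derived_env = {
    "TPU_PROCESS_ADDRESSES": ",".join(
        f"localhost:{BASE_JAX_PORT + i}" for i in range(cfg.pp_world_size)),
    "TPU_PROCESS_PORT": str(BASE_JAX_PORT + cfg.rank),
    "CLOUD_TPU_TASK_ID": str(cfg.rank),
    "TPU_PROCESS_BOUNDS": cfg.default_tpu_process_bounds,  # "1,2,1"
    "TPU_CHIPS_PER_PROCESS_BOUNDS": cfg.default_tpu_chips_per_process_bounds,  # "1,1,1"
    "TPU_VISIBLE_CHIPS": cfg.default_tpu_visible_chips,  # "1"
}
print(derived_env)
```

The in-code comments in the hunk above also describe the alternative v6e-8 split of two 4-chip slices, where callers would instead pass TPU_CHIPS_PER_PROCESS_BOUNDS="1,4,1" and an explicit TPU_VISIBLE_CHIPS list such as "0,1,2,3" or "4,5,6,7".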
@@ -146,15 +220,40 @@ class TPUWorker:
             tensor_model_parallel_size=1,
             pipeline_model_parallel_size=1,
         )
+
+        jax_parallel_state.init_pp_distributed_environment(
+            self.pp_config.ip,
+            self.rank,
+            self.parallel_config.pipeline_parallel_size,
+            self.devices[0],
+            need_pp=self.parallel_config.pipeline_parallel_size > 1)
+
         ensure_kv_transfer_initialized(self.vllm_config)
-
+
+        is_first_rank = True
+        is_last_rank = True
+        if self.parallel_config.pipeline_parallel_size > 1:
+            is_first_rank = self.rank == 0
+            is_last_rank = self.rank == self.pp_config.pp_world_size - 1
+
+        self.model_runner = TPUModelRunner(self.vllm_config, self.devices,
+                                           self.rank, is_first_rank,
+                                           is_last_rank)
         logger.info(f"Init worker | "
                     f"rank={self.rank} | "
+                    f"is_first_rank={is_first_rank} | "
+                    f"is_last_rank={is_last_rank} | "
                     f"node_id={get_node_id()} | "
                     f"is_driver_worker={self.is_driver_worker} | "
                     f"hbm={utils.hbm_usage_gb(self.devices)}GiB")
         vllm_utils.report_usage_stats(self.vllm_config)
 
+    def initialize_pp_transfer_connect(self):
+        if self.rank == 0:
+            return
+        jax_parallel_state.connect(self.pp_config.prev_worker_ip,
+                                   self.rank - 1)
+
     def determine_available_memory(self) -> int:
         gpu_memory_utilization = self.cache_config.gpu_memory_utilization
         hbm_usage = utils.hbm_usage_bytes(self.devices)
@@ -194,14 +293,39 @@ class TPUWorker:
         # deliberate, temporary compromise for the same reasons outlined in
         # the `get_kv_cache_spec` method.
 
-
-
-
-
-
-
-
-
+        if self.parallel_config.pipeline_parallel_size == 1 or self.rank == 0:
+            intermediate_tensors = None
+        else:
+            # receive intermediate tensors
+            uuid = self.model_runner.get_uuid_for_jax_transfer(
+                scheduler_output, self.rank - 1, self.step_counter)
+            # TODO: this method might only works for vllm model, not sure about jax models.
+            tensor_spec = self.model_runner.get_intermediate_tensor_spec(
+                scheduler_output.total_num_scheduled_tokens)
+            intermediate_tensors_dict = get_pp_group().recv_tensor_dict(
+                uuid, tensor_spec)
+            intermediate_tensors = JaxIntermediateTensors(
+                intermediate_tensors_dict)
+
+        output = self.model_runner.execute_model(scheduler_output,
+                                                 intermediate_tensors)
+
+        if isinstance(output, JaxIntermediateTensors):
+            assert self.parallel_config.pipeline_parallel_size > 1
+            assert not get_pp_group().is_last_rank
+            # send intermediate tensors
+            uuid = self.model_runner.get_uuid_for_jax_transfer(
+                scheduler_output, self.rank, self.step_counter)
+            get_pp_group().send_tensor_dict(uuid, output.tensors)
+            self.step_counter += 1
+            return None
+        else:
+            self.step_counter += 1
+            # With a connector, the scheduler expects output from all workers
+            # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
+            if has_kv_transfer_group():
+                return output
+            return output if self.is_driver_worker else None
 
     def sample_tokens(self,
                       grammar_output: GrammarOutput) -> ModelRunnerOutput:
@@ -221,7 +345,7 @@ class TPUWorker:
         if is_start:
             options = jax.profiler.ProfileOptions()
             # default: https://docs.jax.dev/en/latest/profiling.html#general-options
-            options.python_tracer_level =
+            options.python_tracer_level = envs.PYTHON_TRACER_LEVEL
             options.host_tracer_level = os.getenv("HOST_TRACER_LEVEL", 1)
             jax.profiler.start_trace(self.profile_dir,
                                      profiler_options=options)
@@ -259,32 +383,37 @@ class TPUWorker:
         # responsible for this translation. When vLLM can be modified, this
         # method should be changed to return `dict[str, AbstractKVCacheSpec]`,
         # and the vLLM side should be updated to handle the translation.
-
+        kv_cache_spec = self.model_runner.get_kv_cache_spec()
 
-        if len(
-            return
+        if len(kv_cache_spec) == 0:
+            return kv_cache_spec
 
         # TODO(kyuyeunk): Instead of checking page_size_bytes here, introduce
         # feature that allows overriding page_size_bytes of KVCacheSpec.
-        vllm_page_size_bytes = get_uniform_page_size(
-
-
+        vllm_page_size_bytes = get_uniform_page_size(
+            list(kv_cache_spec.values()))
+        attention_page_size_bytes = get_attention_page_size_bytes(
+            self.model_runner.mesh, kv_cache_spec)
 
-        if vllm_page_size_bytes !=
+        if vllm_page_size_bytes != attention_page_size_bytes:
             logger.info(
-                f"
-                f"
-                f"
-                f"
-
+                f"Page size calculated by vLLM ({vllm_page_size_bytes} Bytes) "
+                f"does not match with actual page size used by the kernel "
+                f"({attention_page_size_bytes} Bytes). Recalculating number of "
+                f"KV blocks using actual page size.")
+
+            kv_cache_groups = get_kv_cache_groups(self.vllm_config,
+                                                  kv_cache_spec)
+            group_size = max(
+                len(group.layer_names) for group in kv_cache_groups)
             available_memory = self.determine_available_memory()
-            num_blocks = get_num_blocks(self.vllm_config,
-                                        available_memory,
-
+            num_blocks = get_num_blocks(self.vllm_config, group_size,
+                                        available_memory,
+                                        attention_page_size_bytes)
             cache_config = self.vllm_config.cache_config
             cache_config.num_gpu_blocks_override = num_blocks
 
-        return
+        return kv_cache_spec
 
     def initialize_from_config(
         self,
@@ -319,3 +448,8 @@ class TPUWorker:
 
     def shutdown(self) -> None:
        return
+
+    # Ray executor do not need handshake metadata
+    # as we pass the kv_parameters through proxy server
+    def get_kv_connector_handshake_metadata(self) -> None:
+        pass
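The `execute_model` hunk above gives each pipeline stage a simple contract: every rank except the first receives the previous stage's intermediate tensors (keyed by a per-step uuid derived from the scheduler output and `step_counter`), every rank runs its part of the model, and every rank except the last forwards a `JaxIntermediateTensors` instead of returning scheduler output. The toy, in-process simulation below mirrors only that control flow; `FakePPGroup`, `run_step`, and `demo_stage` are stand-ins invented for illustration, while the real worker goes through `get_pp_group().send_tensor_dict` / `recv_tensor_dict` and `TPUModelRunner.execute_model`.

```python
from dataclasses import dataclass


@dataclass
class JaxIntermediateTensors:
    """Stand-in for the real class; it just wraps a dict of activations."""
    tensors: dict


class FakePPGroup:
    """Toy replacement for get_pp_group(): an in-memory mailbox keyed by uuid."""

    def __init__(self):
        self.mailbox = {}

    def send_tensor_dict(self, uuid, tensors):
        self.mailbox[uuid] = tensors

    def recv_tensor_dict(self, uuid):
        return self.mailbox.pop(uuid)


def run_step(rank, pp_group, step_counter, run_stage):
    """Mirrors the pipeline-parallel branch of TPUWorker.execute_model."""
    # Ranks after the first receive the previous stage's activations.
    intermediate = None if rank == 0 else JaxIntermediateTensors(
        pp_group.recv_tensor_dict(f"step{step_counter}-from-rank{rank - 1}"))
    output = run_stage(rank, intermediate)
    if isinstance(output, JaxIntermediateTensors):
        # Not the last stage: forward activations, return nothing to the scheduler.
        pp_group.send_tensor_dict(f"step{step_counter}-from-rank{rank}",
                                  output.tensors)
        return None
    return output  # last stage returns the real model output


def demo_stage(rank, intermediate):
    hidden = 1 + (intermediate.tensors["hidden"] if intermediate else 0)
    # With 2 stages, rank 1 is last and returns a final output; rank 0 forwards.
    return {"sampled": hidden} if rank == 1 else JaxIntermediateTensors(
        {"hidden": hidden})


group = FakePPGroup()
print(run_step(0, group, step_counter=0, run_stage=demo_stage))  # None
print(run_step(1, group, step_counter=0, run_stage=demo_stage))  # {'sampled': 2}
```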
{tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tpu_inference
-Version: 0.
+Version: 0.12.0.dev20251213
 Author: tpu_inference Contributors
 Classifier: Development Status :: 3 - Alpha
 Classifier: Intended Audience :: Developers
@@ -14,7 +14,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: tpu-info==0.
+Requires-Dist: tpu-info==0.7.1
 Requires-Dist: yapf==0.43.0
 Requires-Dist: pytest
 Requires-Dist: pytest-mock
@@ -25,12 +25,13 @@ Requires-Dist: jax[tpu]==0.8.0
 Requires-Dist: jaxlib==0.8.0
 Requires-Dist: jaxtyping
 Requires-Dist: flax==0.11.1
-Requires-Dist: torchax==0.0.
+Requires-Dist: torchax==0.0.10
 Requires-Dist: qwix==0.1.1
-Requires-Dist: torchvision==0.
+Requires-Dist: torchvision==0.24.0
 Requires-Dist: pathwaysutils
 Requires-Dist: parameterized
 Requires-Dist: numba==0.62.1
+Requires-Dist: runai-model-streamer[gcs,s3]==0.15.0
 Dynamic: author
 Dynamic: classifier
 Dynamic: description
@@ -52,14 +53,12 @@ Dynamic: requires-python
 
 ---
 
-_Upcoming Events_ 🔥
-
-- Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundation.org/pytorch-conference/) in San Francisco!
-- Join us at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
-- Join us at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
-
 _Latest News_ 🔥
 
+- [Pytorch Conference](https://pytorchconference.sched.com/event/27QCh/sponsored-session-everything-everywhere-all-at-once-vllm-hardware-optionality-with-spotify-and-google-brittany-rockwell-google-shireen-kheradpey-spotify) Learn how Spotify uses vLLM with both GPUs and TPUs to drive down costs and improve user experience.
+- Check back soon for a recording of our session at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
+- Check back soon for a recording of our session at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
+
 - [2025/10] [vLLM TPU: A New Unified Backend Supporting PyTorch and JAX on TPU](https://blog.vllm.ai/2025/10/16/vllm-tpu.html)
 
 <details>