tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202512030818__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/lora/test_layers.py +0 -6
- tests/lora/utils.py +0 -8
- tests/test_envs.py +32 -11
- tests/test_utils.py +1 -2
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +3 -4
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +61 -8
- tpu_inference/executors/ray_distributed_executor.py +31 -11
- tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +213 -126
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +74 -25
- tpu_inference/layers/vllm/quantization/common.py +6 -1
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -62
- tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +45 -11
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +3 -6
- tpu_inference/models/jax/utils/weight_utils.py +198 -143
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -7
- tpu_inference/platforms/tpu_platform.py +28 -22
- tpu_inference/runner/compilation_manager.py +144 -59
- tpu_inference/runner/kv_cache_manager.py +17 -18
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +271 -147
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -21
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +36 -13
- tpu_inference/worker/tpu_worker.py +162 -25
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/METADATA +3 -2
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/RECORD +48 -53
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/top_level.txt +0 -0
tpu_inference/distributed/utils.py
CHANGED
@@ -2,6 +2,7 @@ import os
 
 from vllm.utils.network_utils import get_ip
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
@@ -17,7 +18,7 @@ def set_node_kv_ip_port(ip_port: tuple[int, str, int]):
 
 
 def get_kv_ips() -> str:
-    if
+    if envs.TPU_MULTIHOST_BACKEND == "ray":
         num_nodes = len(_NODES_KV_IP_PORT)
         ips = []
         for node_id in range(num_nodes):
@@ -28,7 +29,7 @@ def get_kv_ips() -> str:
 
 
 def get_kv_ports() -> str:
-    if
+    if envs.TPU_MULTIHOST_BACKEND == "ray":
         num_nodes = len(_NODES_KV_IP_PORT)
         ports = []
         for node_id in range(num_nodes):
tpu_inference/envs.py
CHANGED
@@ -15,18 +15,64 @@ if TYPE_CHECKING:
     PREFILL_SLICES: str = ""
     DECODE_SLICES: str = ""
     SKIP_JAX_PRECOMPILE: bool = False
+    VLLM_XLA_CHECK_RECOMPILATION: bool = False
     MODEL_IMPL_TYPE: str = "flax_nnx"
     NEW_MODEL_DESIGN: bool = False
     PHASED_PROFILING_DIR: str = ""
     PYTHON_TRACER_LEVEL: int = 1
     USE_MOE_EP_KERNEL: bool = False
+    NUM_SLICES: int = 1
     RAY_USAGE_STATS_ENABLED: str = "0"
     VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"
 
+
+def env_with_choices(
+    env_name: str,
+    default: str | None,
+    choices: list[str] | Callable[[], list[str]],
+    case_sensitive: bool = True,
+) -> Callable[[], str | None]:
+    """
+    Create a lambda that validates environment variable against allowed choices
+
+    Args:
+        env_name: Name of the environment variable
+        default: Default value if not set (can be None)
+        choices: List of valid string options or callable that returns list
+        case_sensitive: Whether validation should be case sensitive
+
+    Returns:
+        Lambda function for environment_variables dict
+    """
+
+    def _get_validated_env() -> str | None:
+        value = os.getenv(env_name)
+        if value is None:
+            return default
+
+        # Resolve choices if it's a callable (for lazy loading)
+        actual_choices = choices() if callable(choices) else choices
+
+        if not case_sensitive:
+            check_value = value.lower()
+            check_choices = [choice.lower() for choice in actual_choices]
+        else:
+            check_value = value
+            check_choices = actual_choices
+
+        if check_value not in check_choices:
+            raise ValueError(f"Invalid value '{value}' for {env_name}. "
+                             f"Valid options: {actual_choices}.")
+
+        return value
+
+    return _get_validated_env
+
+
 environment_variables: dict[str, Callable[[], Any]] = {
     # JAX platform selection (e.g., "tpu", "cpu", "proxy")
     "JAX_PLATFORMS":
-    lambda: os.getenv("JAX_PLATFORMS", ""),
+    lambda: os.getenv("JAX_PLATFORMS", "").lower(),
     # TPU accelerator type (e.g., "v5litepod-16", "v4-8")
     "TPU_ACCELERATOR_TYPE":
     lambda: os.getenv("TPU_ACCELERATOR_TYPE", None),
@@ -38,7 +84,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: os.getenv("TPU_WORKER_ID", None),
     # Backend for multi-host communication on TPU
     "TPU_MULTIHOST_BACKEND":
-
+    env_with_choices("TPU_MULTIHOST_BACKEND", "", ["ray"]),
     # Slice configuration for disaggregated prefill workers
     "PREFILL_SLICES":
     lambda: os.getenv("PREFILL_SLICES", ""),
@@ -47,28 +93,35 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: os.getenv("DECODE_SLICES", ""),
     # Skip JAX precompilation step during initialization
     "SKIP_JAX_PRECOMPILE":
-    lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE"
+    lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE") or "0")),
+    # Check for XLA recompilation during execution
+    "VLLM_XLA_CHECK_RECOMPILATION":
+    lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION") or "0")),
     # Model implementation type (e.g., "flax_nnx")
     "MODEL_IMPL_TYPE":
-
+    env_with_choices("MODEL_IMPL_TYPE", "flax_nnx",
+                     ["vllm", "flax_nnx", "jetpack"]),
     # Enable new experimental model design
     "NEW_MODEL_DESIGN":
-    lambda: bool(int(os.getenv("NEW_MODEL_DESIGN"
+    lambda: bool(int(os.getenv("NEW_MODEL_DESIGN") or "0")),
     # Directory to store phased profiling output
     "PHASED_PROFILING_DIR":
     lambda: os.getenv("PHASED_PROFILING_DIR", ""),
     # Python tracer level for profiling
     "PYTHON_TRACER_LEVEL":
-    lambda: int(os.getenv("PYTHON_TRACER_LEVEL"
+    lambda: int(os.getenv("PYTHON_TRACER_LEVEL") or "1"),
     # Use custom expert-parallel kernel for MoE (Mixture of Experts)
     "USE_MOE_EP_KERNEL":
-    lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL"
+    lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL") or "0")),
+    # Number of TPU slices for multi-slice mesh
+    "NUM_SLICES":
+    lambda: int(os.getenv("NUM_SLICES") or "1"),
     # Enable/disable Ray usage statistics collection
     "RAY_USAGE_STATS_ENABLED":
     lambda: os.getenv("RAY_USAGE_STATS_ENABLED", "0"),
     # Ray compiled DAG channel type for TPU
     "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
-
+    env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm", ["shm"]),
 }
 
 
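The new env_with_choices helper replaces ad-hoc os.getenv lambdas for enumerated settings such as TPU_MULTIHOST_BACKEND and MODEL_IMPL_TYPE: the value is read lazily on each lookup and rejected with a ValueError if it is not one of the allowed choices. A minimal sketch of that behavior, assuming (as in vLLM's envs module, not shown in this diff) that attribute access like envs.TPU_MULTIHOST_BACKEND is routed through a module-level __getattr__ over environment_variables:

# Sketch only: a trimmed copy of env_with_choices plus an assumed module-level
# __getattr__; the real tpu_inference/envs.py is the source of truth.
import os
from typing import Any, Callable


def env_with_choices(env_name: str, default: str | None,
                     choices: list[str],
                     case_sensitive: bool = True) -> Callable[[], str | None]:
    def _get_validated_env() -> str | None:
        value = os.getenv(env_name)
        if value is None:
            return default
        check_value = value if case_sensitive else value.lower()
        check_choices = choices if case_sensitive else [c.lower() for c in choices]
        if check_value not in check_choices:
            raise ValueError(f"Invalid value '{value}' for {env_name}. "
                             f"Valid options: {choices}.")
        return value
    return _get_validated_env


environment_variables: dict[str, Callable[[], Any]] = {
    "TPU_MULTIHOST_BACKEND": env_with_choices("TPU_MULTIHOST_BACKEND", "", ["ray"]),
}


def __getattr__(name: str) -> Any:
    # Assumed resolver: envs.TPU_MULTIHOST_BACKEND re-reads the environment on
    # every access instead of caching a value at import time.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(name)


if __name__ == "__main__":
    os.environ["TPU_MULTIHOST_BACKEND"] = "ray"
    print(environment_variables["TPU_MULTIHOST_BACKEND"]())   # "ray"
    os.environ["TPU_MULTIHOST_BACKEND"] = "gloo"               # not an allowed choice
    try:
        environment_variables["TPU_MULTIHOST_BACKEND"]()
    except ValueError as err:
        print(err)  # Invalid value 'gloo' for TPU_MULTIHOST_BACKEND. Valid options: ['ray'].

The next hunks come from tpu_inference/executors/ray_distributed_executor.py (see the file list above).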
tpu_inference/executors/ray_distributed_executor.py
CHANGED

@@ -108,6 +108,9 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
         ip_port = self.collective_rpc("get_node_kv_ip_port")
         for item in ip_port:
             set_node_kv_ip_port(item)
+        self.uses_sampler = self.vllm_config.model_config.runner_type != "pooling" and (
+            self.vllm_config.ec_transfer_config is None
+            or not self.vllm_config.ec_transfer_config.is_ec_producer)
 
     def _initialize_ray_cluster(self) -> None:
         """Initialize the distributed cluster with Ray.
@@ -131,10 +134,21 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
                 f"current platform {current_platform.device_name} does not "
                 "support ray.")
 
-
-
-
-
+        pp_size = self.parallel_config.pipeline_parallel_size
+        placement_group_specs: List[Dict[str, float]] = []
+
+        ray_nodes = ray.nodes()
+        logger.info(f"RayDistributedExecutor | ray_nodes={ray_nodes}")
+
+        if pp_size == 1:
+            placement_group_specs = [{
+                device_str: node['Resources'][device_str]
+            } for node in ray_nodes]
+        else:
+            num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
+            placement_group_specs = [{
+                device_str: num_devices_per_pp_rank
+            } for _ in range(pp_size)]
 
         # vLLM engine is also a worker to execute model with an accelerator,
         # so it requires to have the device in a current node. Check if
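The _initialize_ray_cluster change above sizes Ray placement-group bundles differently depending on pipeline parallelism: with PP=1 it requests every TPU advertised by each Ray node, with PP>1 it requests one bundle per PP rank sized to sharding_config.total_devices. A small sketch of that bundle construction, with device_str, the node resources, and the per-rank device count as stand-in values:

# Hedged sketch of the placement-group bundles the executor now requests.
# device_str, the node resources, and num_devices_per_pp_rank are stand-ins
# for values that come from vLLM/Ray at runtime.
from typing import Dict, List

device_str = "TPU"


def build_placement_group_specs(pp_size: int,
                                ray_nodes: List[Dict],
                                num_devices_per_pp_rank: int) -> List[Dict[str, float]]:
    if pp_size == 1:
        # One bundle per Ray node, claiming every TPU that node advertises.
        return [{device_str: node["Resources"][device_str]} for node in ray_nodes]
    # With pipeline parallelism, one bundle per PP rank, each sized to the
    # devices a single PP stage needs.
    return [{device_str: float(num_devices_per_pp_rank)} for _ in range(pp_size)]


# Example: two nodes with 4 TPUs each.
nodes = [{"Resources": {"TPU": 4.0}}, {"Resources": {"TPU": 4.0}}]
print(build_placement_group_specs(1, nodes, 4))  # [{'TPU': 4.0}, {'TPU': 4.0}]
print(build_placement_group_specs(2, nodes, 4))  # [{'TPU': 4.0}, {'TPU': 4.0}]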
@@ -329,6 +343,8 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
         all_kwargs = []
         for rank, (node_id, _) in enumerate(worker_node_and_tpu_ids):
             local_rank = node_workers[node_id].index(rank)
+            ip = sorted_worker_metadata[rank].ip
+            prev_ip = sorted_worker_metadata[rank - 1].ip if rank > 0 else ""
             kwargs = dict(
                 vllm_config=self.vllm_config,
                 local_rank=local_rank,
@@ -336,22 +352,26 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
                 distributed_init_method=distributed_init_method,
                 is_driver_worker=(not self.parallel_config)
                 or (rank % self.parallel_config.tensor_parallel_size == 0),
+                ip=ip,
+                prev_worker_ip=prev_ip,
             )
             all_kwargs.append(kwargs)
         self.collective_rpc("init_worker", args=(all_kwargs, ))
         self.collective_rpc("init_device")
+        if self.parallel_config.pipeline_parallel_size > 1:
+            self.collective_rpc("initialize_pp_transfer_connect")
         self.collective_rpc("load_model")
 
         if self.use_ray_spmd_worker:
             for pp_rank in range(self.parallel_config.pipeline_parallel_size):
                 self.pp_tp_workers.append([])
-
-
-
-
-                #
-
-
+                num_tp_workers = int(
+                    self.parallel_config.tensor_parallel_size //
+                    num_tpu_per_worker)
+                for tp_rank in range(num_tp_workers):
+                    # PP=2, TP=4, num_tpu_per_worker=2
+                    # pp_tp_workers = [[0, 1], [2, 3]]
+                    rank = (pp_rank * num_tp_workers) + tp_rank
                     assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                     assert pp_rank < len(self.pp_tp_workers)
                     self.pp_tp_workers[pp_rank].append(self.workers[rank])