tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511130813__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +34 -303
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
- tests/lora/test_layers.py +6 -0
- tests/lora/utils.py +8 -0
- tests/test_utils.py +16 -24
- tpu_inference/__init__.py +3 -22
- tpu_inference/core/core_tpu.py +9 -17
- tpu_inference/core/disagg_utils.py +8 -6
- tpu_inference/distributed/tpu_connector.py +4 -3
- tpu_inference/distributed/utils.py +2 -3
- tpu_inference/envs.py +8 -61
- tpu_inference/executors/ray_distributed_executor.py +11 -31
- tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +143 -287
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -7
- tpu_inference/layers/jax/attention/attention.py +1 -1
- tpu_inference/layers/{common → jax}/attention_interface.py +2 -8
- tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
- tpu_inference/layers/jax/sample/sampling.py +2 -2
- tpu_inference/layers/{common → jax}/sharding.py +5 -5
- tpu_inference/layers/vllm/attention.py +1 -1
- tpu_inference/layers/vllm/fused_moe.py +208 -170
- tpu_inference/layers/vllm/quantization/__init__.py +3 -7
- tpu_inference/layers/vllm/quantization/awq.py +3 -4
- tpu_inference/layers/vllm/quantization/common.py +1 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +2 -4
- tpu_inference/layers/vllm/quantization/unquantized.py +67 -62
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +2 -1
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/common/model_loader.py +12 -46
- tpu_inference/models/jax/llama3.py +3 -4
- tpu_inference/models/jax/llama_eagle3.py +5 -8
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +2 -3
- tpu_inference/models/jax/qwen2_5_vl.py +50 -165
- tpu_inference/models/jax/qwen3.py +2 -3
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
- tpu_inference/models/jax/utils/weight_utils.py +143 -198
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -32
- tpu_inference/platforms/tpu_platform.py +34 -47
- tpu_inference/runner/compilation_manager.py +60 -145
- tpu_inference/runner/kv_cache.py +2 -2
- tpu_inference/runner/kv_cache_manager.py +18 -17
- tpu_inference/runner/persistent_batch_manager.py +2 -40
- tpu_inference/runner/structured_decoding_manager.py +3 -2
- tpu_inference/runner/tpu_runner.py +135 -283
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +21 -71
- tpu_inference/tpu_info.py +3 -4
- tpu_inference/utils.py +15 -38
- tpu_inference/worker/tpu_worker.py +26 -163
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/METADATA +3 -4
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/RECORD +63 -61
- tests/test_envs.py +0 -203
- tpu_inference/layers/common/quant_methods.py +0 -8
- tpu_inference/layers/vllm/quantization/mxfp4.py +0 -331
- tpu_inference/models/jax/llama_guard_4.py +0 -361
- /tpu_inference/layers/{common → jax}/binary_search.py +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/WHEEL +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/top_level.txt +0 -0
@@ -60,6 +60,7 @@ D workflow:
 
 import copy
 import functools
+import os
 import threading
 import time
 from concurrent.futures import Future, ThreadPoolExecutor
@@ -85,7 +86,6 @@ if TYPE_CHECKING:
     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
     from vllm.v1.request import Request
 
-from tpu_inference import envs
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_ips,
                                              get_kv_ports,
                                              get_kv_transfer_port, get_node_id,
@@ -441,7 +441,8 @@ class TPUConnectorWorker:
 
         self.runner: TPUModelRunner = None
         self.mesh: Mesh = None
-        self.multi_host =
+        self.multi_host = os.getenv("TPU_MULTIHOST_BACKEND",
+                                    "").lower() == "ray"
         # NOTE(xiang): This can not be the worker rank set in RayDistributedExecutor.
         # The worker rank is assigned with vLLM's sorting logic, which does not work
         # for TPU host topology.
@@ -457,6 +458,7 @@ class TPUConnectorWorker:
         self.side_channel_port = get_side_channel_port()
 
         self.kv_transfer_server = None
+        self._maybe_start_p2p_server()
         self.zmq_cxt = zmq.Context()
         if self.is_producer:
             ready_event = threading.Event()
@@ -498,7 +500,6 @@
         self.shape = list(kv_layer.shape)
         self.dtype = kv_layer.dtype
         self.sharding = kv_layer.sharding
-        self._maybe_start_p2p_server()
 
     def _maybe_start_p2p_server(self):
         if self.kv_transfer_server is not None:
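For context, the removed `from tpu_inference import envs` import and the rewritten `self.multi_host` assignment point at the same pattern: the backend check is now inlined against the raw environment. A minimal sketch of that check, using a hypothetical helper name and assuming only what the added lines show:

    import os

    def multihost_backend_is_ray() -> bool:
        # Hypothetical helper mirroring the inlined check added above:
        # TPU_MULTIHOST_BACKEND is read directly from the environment and
        # lower-cased before comparing against "ray".
        return os.getenv("TPU_MULTIHOST_BACKEND", "").lower() == "ray"

    # e.g. self.multi_host = multihost_backend_is_ray()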
tpu_inference/distributed/utils.py CHANGED

@@ -2,7 +2,6 @@ import os
 
 from vllm.utils.network_utils import get_ip
 
-from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
@@ -18,7 +17,7 @@ def set_node_kv_ip_port(ip_port: tuple[int, str, int]):
 
 
 def get_kv_ips() -> str:
-    if
+    if os.getenv("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
         num_nodes = len(_NODES_KV_IP_PORT)
         ips = []
         for node_id in range(num_nodes):
@@ -29,7 +28,7 @@ def get_kv_ips() -> str:
 
 
 def get_kv_ports() -> str:
-    if
+    if os.getenv("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
         num_nodes = len(_NODES_KV_IP_PORT)
         ports = []
         for node_id in range(num_nodes):
tpu_inference/envs.py CHANGED

@@ -15,64 +15,18 @@ if TYPE_CHECKING:
     PREFILL_SLICES: str = ""
     DECODE_SLICES: str = ""
     SKIP_JAX_PRECOMPILE: bool = False
-    VLLM_XLA_CHECK_RECOMPILATION: bool = False
     MODEL_IMPL_TYPE: str = "flax_nnx"
     NEW_MODEL_DESIGN: bool = False
     PHASED_PROFILING_DIR: str = ""
     PYTHON_TRACER_LEVEL: int = 1
     USE_MOE_EP_KERNEL: bool = False
-    NUM_SLICES: int = 1
     RAY_USAGE_STATS_ENABLED: str = "0"
     VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"
 
-
-def env_with_choices(
-    env_name: str,
-    default: str | None,
-    choices: list[str] | Callable[[], list[str]],
-    case_sensitive: bool = True,
-) -> Callable[[], str | None]:
-    """
-    Create a lambda that validates environment variable against allowed choices
-
-    Args:
-        env_name: Name of the environment variable
-        default: Default value if not set (can be None)
-        choices: List of valid string options or callable that returns list
-        case_sensitive: Whether validation should be case sensitive
-
-    Returns:
-        Lambda function for environment_variables dict
-    """
-
-    def _get_validated_env() -> str | None:
-        value = os.getenv(env_name)
-        if value is None:
-            return default
-
-        # Resolve choices if it's a callable (for lazy loading)
-        actual_choices = choices() if callable(choices) else choices
-
-        if not case_sensitive:
-            check_value = value.lower()
-            check_choices = [choice.lower() for choice in actual_choices]
-        else:
-            check_value = value
-            check_choices = actual_choices
-
-        if check_value not in check_choices:
-            raise ValueError(f"Invalid value '{value}' for {env_name}. "
-                             f"Valid options: {actual_choices}.")
-
-        return value
-
-    return _get_validated_env
-
-
 environment_variables: dict[str, Callable[[], Any]] = {
     # JAX platform selection (e.g., "tpu", "cpu", "proxy")
     "JAX_PLATFORMS":
-    lambda: os.getenv("JAX_PLATFORMS", "")
+    lambda: os.getenv("JAX_PLATFORMS", ""),
     # TPU accelerator type (e.g., "v5litepod-16", "v4-8")
     "TPU_ACCELERATOR_TYPE":
     lambda: os.getenv("TPU_ACCELERATOR_TYPE", None),
@@ -84,7 +38,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: os.getenv("TPU_WORKER_ID", None),
     # Backend for multi-host communication on TPU
     "TPU_MULTIHOST_BACKEND":
-
+    lambda: os.getenv("TPU_MULTIHOST_BACKEND", "").lower(),
     # Slice configuration for disaggregated prefill workers
     "PREFILL_SLICES":
     lambda: os.getenv("PREFILL_SLICES", ""),
@@ -93,35 +47,28 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: os.getenv("DECODE_SLICES", ""),
     # Skip JAX precompilation step during initialization
     "SKIP_JAX_PRECOMPILE":
-    lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE"
-    # Check for XLA recompilation during execution
-    "VLLM_XLA_CHECK_RECOMPILATION":
-    lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION") or "0")),
+    lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE", "0"))),
     # Model implementation type (e.g., "flax_nnx")
     "MODEL_IMPL_TYPE":
-
-    ["vllm", "flax_nnx", "jetpack"]),
+    lambda: os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower(),
     # Enable new experimental model design
     "NEW_MODEL_DESIGN":
-    lambda: bool(int(os.getenv("NEW_MODEL_DESIGN"
+    lambda: bool(int(os.getenv("NEW_MODEL_DESIGN", "0"))),
     # Directory to store phased profiling output
     "PHASED_PROFILING_DIR":
     lambda: os.getenv("PHASED_PROFILING_DIR", ""),
     # Python tracer level for profiling
     "PYTHON_TRACER_LEVEL":
-    lambda: int(os.getenv("PYTHON_TRACER_LEVEL"
+    lambda: int(os.getenv("PYTHON_TRACER_LEVEL", "1")),
     # Use custom expert-parallel kernel for MoE (Mixture of Experts)
     "USE_MOE_EP_KERNEL":
-    lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL"
-    # Number of TPU slices for multi-slice mesh
-    "NUM_SLICES":
-    lambda: int(os.getenv("NUM_SLICES") or "1"),
+    lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL", "0"))),
     # Enable/disable Ray usage statistics collection
     "RAY_USAGE_STATS_ENABLED":
     lambda: os.getenv("RAY_USAGE_STATS_ENABLED", "0"),
     # Ray compiled DAG channel type for TPU
     "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
-
+    lambda: os.getenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm"),
 }
 
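The hunks above drop the validated accessors (the deleted `env_with_choices`) in favor of plain `os.getenv` lambdas with inline defaults. In the usual envs-module pattern this file appears to follow (an assumption, since the access hook lies outside the shown hunks), those lambdas are evaluated lazily on attribute access via a module-level `__getattr__`:

    # Sketch of the assumed lazy-access pattern; everything except the
    # environment variable names themselves is illustrative.
    import os
    from typing import Any, Callable

    environment_variables: dict[str, Callable[[], Any]] = {
        "SKIP_JAX_PRECOMPILE": lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE", "0"))),
        "MODEL_IMPL_TYPE": lambda: os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower(),
    }

    def __getattr__(name: str) -> Any:
        # PEP 562 hook: called on module attribute access, so envs.MODEL_IMPL_TYPE
        # re-reads the current process environment each time.
        if name in environment_variables:
            return environment_variables[name]()
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")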
tpu_inference/executors/ray_distributed_executor.py CHANGED

@@ -108,9 +108,6 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
         ip_port = self.collective_rpc("get_node_kv_ip_port")
         for item in ip_port:
             set_node_kv_ip_port(item)
-        self.uses_sampler = self.vllm_config.model_config.runner_type != "pooling" and (
-            self.vllm_config.ec_transfer_config is None
-            or not self.vllm_config.ec_transfer_config.is_ec_producer)
 
     def _initialize_ray_cluster(self) -> None:
         """Initialize the distributed cluster with Ray.
@@ -134,21 +131,10 @@
                 f"current platform {current_platform.device_name} does not "
                 "support ray.")
 
-
-
-
-
-        logger.info(f"RayDistributedExecutor | ray_nodes={ray_nodes}")
-
-        if pp_size == 1:
-            placement_group_specs = [{
-                device_str: node['Resources'][device_str]
-            } for node in ray_nodes]
-        else:
-            num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
-            placement_group_specs = [{
-                device_str: num_devices_per_pp_rank
-            } for _ in range(pp_size)]
+        placement_group_specs: List[Dict[str, float]] = [{
+            device_str:
+            node['Resources'][device_str]
+        } for node in ray.nodes()]
 
         # vLLM engine is also a worker to execute model with an accelerator,
         # so it requires to have the device in a current node. Check if
@@ -343,8 +329,6 @@
         all_kwargs = []
         for rank, (node_id, _) in enumerate(worker_node_and_tpu_ids):
             local_rank = node_workers[node_id].index(rank)
-            ip = sorted_worker_metadata[rank].ip
-            prev_ip = sorted_worker_metadata[rank - 1].ip if rank > 0 else ""
             kwargs = dict(
                 vllm_config=self.vllm_config,
                 local_rank=local_rank,
@@ -352,26 +336,22 @@
                 distributed_init_method=distributed_init_method,
                 is_driver_worker=(not self.parallel_config)
                 or (rank % self.parallel_config.tensor_parallel_size == 0),
-                ip=ip,
-                prev_worker_ip=prev_ip,
             )
             all_kwargs.append(kwargs)
         self.collective_rpc("init_worker", args=(all_kwargs, ))
         self.collective_rpc("init_device")
-        if self.parallel_config.pipeline_parallel_size > 1:
-            self.collective_rpc("initialize_pp_transfer_connect")
         self.collective_rpc("load_model")
 
         if self.use_ray_spmd_worker:
             for pp_rank in range(self.parallel_config.pipeline_parallel_size):
                 self.pp_tp_workers.append([])
-
-
-
-
-                #
-
-
+                for tp_rank in range(
+                        int(self.parallel_config.tensor_parallel_size //
+                            num_tpu_per_worker)):
+                    # PP=2, TP=4
+                    # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
+                    rank = (pp_rank * self.parallel_config.tensor_parallel_size
+                            ) + tp_rank
                     assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                     assert pp_rank < len(self.pp_tp_workers)
                     self.pp_tp_workers[pp_rank].append(self.workers[rank])