tpu-inference 0.11.1.dev202511130813__py3-none-any.whl → 0.11.1.dev202511220812__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/lora/test_layers.py +0 -6
- tests/lora/utils.py +0 -8
- tests/test_envs.py +182 -0
- tests/test_utils.py +23 -14
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/core_tpu.py +17 -9
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +2 -3
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +1 -1
- tpu_inference/executors/ray_distributed_executor.py +27 -11
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +110 -64
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +7 -0
- tpu_inference/layers/{jax → common}/attention_interface.py +1 -1
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/jax/attention/attention.py +1 -1
- tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
- tpu_inference/layers/jax/sample/sampling.py +2 -2
- tpu_inference/layers/vllm/attention.py +1 -1
- tpu_inference/layers/vllm/quantization/__init__.py +7 -3
- tpu_inference/layers/vllm/quantization/awq.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -2
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +4 -3
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +12 -11
- tpu_inference/models/jax/llama3.py +4 -3
- tpu_inference/models/jax/llama_eagle3.py +9 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +3 -2
- tpu_inference/models/jax/qwen2_5_vl.py +4 -3
- tpu_inference/models/jax/qwen3.py +3 -2
- tpu_inference/models/jax/utils/weight_utils.py +21 -8
- tpu_inference/models/vllm/vllm_model_wrapper.py +22 -10
- tpu_inference/platforms/tpu_platform.py +17 -7
- tpu_inference/runner/compilation_manager.py +37 -17
- tpu_inference/runner/kv_cache.py +1 -1
- tpu_inference/runner/kv_cache_manager.py +8 -2
- tpu_inference/runner/tpu_runner.py +199 -87
- tpu_inference/spec_decode/jax/eagle3.py +2 -1
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +7 -6
- tpu_inference/worker/tpu_worker.py +159 -23
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/METADATA +2 -2
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/RECORD +52 -54
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- /tpu_inference/layers/{jax → common}/binary_search.py +0 -0
- /tpu_inference/layers/{jax → common}/sharding.py +0 -0
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/top_level.txt +0 -0
tests/lora/test_layers.py
CHANGED
|
@@ -91,7 +91,6 @@ def populate_loras(
|
|
|
91
91
|
index_to_id: list[Optional[int]],
|
|
92
92
|
lora_layer: BaseLayerWithLoRA,
|
|
93
93
|
baselayer_weights: torch.Tensor,
|
|
94
|
-
generate_embeddings_tensor: int = 0,
|
|
95
94
|
repeats: int = 1,
|
|
96
95
|
) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
|
|
97
96
|
"""This method populates the lora weights (lora_a and lora_b) in the lora layers (BaseLayerWithLoRA).
|
|
@@ -103,8 +102,6 @@ def populate_loras(
|
|
|
103
102
|
lora_layer: the LoRAlayer to populate.
|
|
104
103
|
baselayer_weights: the PyTorch tensor containing the layer's
|
|
105
104
|
weights.
|
|
106
|
-
generate_embeddings_tensor: whether to generate an
|
|
107
|
-
embeddings tensor for each LoRA.
|
|
108
105
|
repeats: must only be set for column parallel packed
|
|
109
106
|
layers. Indicates the number of loras to compose
|
|
110
107
|
together to create a single lora layer.
|
|
@@ -131,7 +128,6 @@ def populate_loras(
|
|
|
131
128
|
baselayer_weights.device).init_random_lora(
|
|
132
129
|
module_name=f"fake_{i}",
|
|
133
130
|
weight=baselayer_weights,
|
|
134
|
-
generate_embeddings_tensor=generate_embeddings_tensor,
|
|
135
131
|
)
|
|
136
132
|
sublora.lora_b = sublora.lora_b[(sublora_len *
|
|
137
133
|
i):(sublora_len * (i + 1)), :]
|
|
@@ -147,7 +143,6 @@ def populate_loras(
|
|
|
147
143
|
slot_idx,
|
|
148
144
|
lora_a=lora.lora_a,
|
|
149
145
|
lora_b=lora.lora_b,
|
|
150
|
-
embeddings_tensor=lora.embeddings_tensor,
|
|
151
146
|
)
|
|
152
147
|
|
|
153
148
|
lora_dict[lora_id] = lora
|
|
@@ -546,7 +541,6 @@ def _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
|
|
|
546
541
|
index_to_id,
|
|
547
542
|
lora_config.max_loras,
|
|
548
543
|
vocab_size=512,
|
|
549
|
-
extra_vocab_size=lora_config.lora_extra_vocab_size,
|
|
550
544
|
)
|
|
551
545
|
assert jax_view(punica_wrapper._lora_indices_per_batch).platform(
|
|
552
546
|
) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
|
tests/lora/utils.py
CHANGED
|
@@ -24,7 +24,6 @@ class DummyLoRAManager:
|
|
|
24
24
|
module_name: str,
|
|
25
25
|
weight: torch.Tensor,
|
|
26
26
|
rank: int = 8,
|
|
27
|
-
generate_embeddings_tensor: int = 0,
|
|
28
27
|
):
|
|
29
28
|
lora = LoRALayerWeights(
|
|
30
29
|
module_name,
|
|
@@ -37,13 +36,6 @@ class DummyLoRAManager:
|
|
|
37
36
|
dtype=weight.dtype,
|
|
38
37
|
device=self._device),
|
|
39
38
|
)
|
|
40
|
-
if generate_embeddings_tensor:
|
|
41
|
-
lora.embeddings_tensor = torch.rand(
|
|
42
|
-
5,
|
|
43
|
-
generate_embeddings_tensor,
|
|
44
|
-
dtype=weight.dtype,
|
|
45
|
-
device=self._device,
|
|
46
|
-
)
|
|
47
39
|
self.set_module_lora(module_name, lora)
|
|
48
40
|
|
|
49
41
|
return lora
|
tests/test_envs.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
# SPDX-FileCopyrightText: Copyright contributors to the tpu-inference project
|
|
3
|
+
|
|
4
|
+
import pytest
|
|
5
|
+
|
|
6
|
+
import tpu_inference.envs as envs
|
|
7
|
+
from tpu_inference.envs import enable_envs_cache, environment_variables
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def test_getattr_without_cache(monkeypatch: pytest.MonkeyPatch):
|
|
11
|
+
assert envs.JAX_PLATFORMS == ""
|
|
12
|
+
assert envs.PHASED_PROFILING_DIR == ""
|
|
13
|
+
monkeypatch.setenv("JAX_PLATFORMS", "tpu")
|
|
14
|
+
monkeypatch.setenv("PHASED_PROFILING_DIR", "/tmp/profiling")
|
|
15
|
+
assert envs.JAX_PLATFORMS == "tpu"
|
|
16
|
+
assert envs.PHASED_PROFILING_DIR == "/tmp/profiling"
|
|
17
|
+
|
|
18
|
+
assert envs.TPU_NAME is None
|
|
19
|
+
assert envs.TPU_ACCELERATOR_TYPE is None
|
|
20
|
+
monkeypatch.setenv("TPU_NAME", "my-tpu")
|
|
21
|
+
monkeypatch.setenv("TPU_ACCELERATOR_TYPE", "v5litepod-16")
|
|
22
|
+
assert envs.TPU_NAME == "my-tpu"
|
|
23
|
+
assert envs.TPU_ACCELERATOR_TYPE == "v5litepod-16"
|
|
24
|
+
|
|
25
|
+
# __getattr__ is not decorated with functools.cache
|
|
26
|
+
assert not hasattr(envs.__getattr__, "cache_info")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch):
|
|
30
|
+
monkeypatch.setenv("JAX_PLATFORMS", "tpu")
|
|
31
|
+
monkeypatch.setenv("TPU_NAME", "my-tpu")
|
|
32
|
+
|
|
33
|
+
# __getattr__ is not decorated with functools.cache
|
|
34
|
+
assert not hasattr(envs.__getattr__, "cache_info")
|
|
35
|
+
|
|
36
|
+
enable_envs_cache()
|
|
37
|
+
|
|
38
|
+
# __getattr__ is decorated with functools.cache
|
|
39
|
+
assert hasattr(envs.__getattr__, "cache_info")
|
|
40
|
+
start_hits = envs.__getattr__.cache_info().hits
|
|
41
|
+
|
|
42
|
+
# 2 more hits due to JAX_PLATFORMS and TPU_NAME accesses
|
|
43
|
+
assert envs.JAX_PLATFORMS == "tpu"
|
|
44
|
+
assert envs.TPU_NAME == "my-tpu"
|
|
45
|
+
assert envs.__getattr__.cache_info().hits == start_hits + 2
|
|
46
|
+
|
|
47
|
+
# All environment variables are cached
|
|
48
|
+
for environment_variable in environment_variables:
|
|
49
|
+
envs.__getattr__(environment_variable)
|
|
50
|
+
assert envs.__getattr__.cache_info(
|
|
51
|
+
).hits == start_hits + 2 + len(environment_variables)
|
|
52
|
+
|
|
53
|
+
# Reset envs.__getattr__ back to non-cached version to
|
|
54
|
+
# avoid affecting other tests
|
|
55
|
+
envs.__getattr__ = envs.__getattr__.__wrapped__
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
|
|
59
|
+
# Test SKIP_JAX_PRECOMPILE (default False)
|
|
60
|
+
assert envs.SKIP_JAX_PRECOMPILE is False
|
|
61
|
+
monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "1")
|
|
62
|
+
assert envs.SKIP_JAX_PRECOMPILE is True
|
|
63
|
+
monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
|
|
64
|
+
assert envs.SKIP_JAX_PRECOMPILE is False
|
|
65
|
+
|
|
66
|
+
# Test NEW_MODEL_DESIGN (default False)
|
|
67
|
+
assert envs.NEW_MODEL_DESIGN is False
|
|
68
|
+
monkeypatch.setenv("NEW_MODEL_DESIGN", "1")
|
|
69
|
+
assert envs.NEW_MODEL_DESIGN is True
|
|
70
|
+
|
|
71
|
+
# Test USE_MOE_EP_KERNEL (default False)
|
|
72
|
+
assert envs.USE_MOE_EP_KERNEL is False
|
|
73
|
+
monkeypatch.setenv("USE_MOE_EP_KERNEL", "1")
|
|
74
|
+
assert envs.USE_MOE_EP_KERNEL is True
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch):
|
|
78
|
+
assert envs.PYTHON_TRACER_LEVEL == 1
|
|
79
|
+
monkeypatch.setenv("PYTHON_TRACER_LEVEL", "3")
|
|
80
|
+
assert envs.PYTHON_TRACER_LEVEL == 3
|
|
81
|
+
monkeypatch.setenv("PYTHON_TRACER_LEVEL", "0")
|
|
82
|
+
assert envs.PYTHON_TRACER_LEVEL == 0
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def test_lowercase_conversion(monkeypatch: pytest.MonkeyPatch):
|
|
86
|
+
monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "GRPC")
|
|
87
|
+
assert envs.TPU_MULTIHOST_BACKEND == "grpc"
|
|
88
|
+
|
|
89
|
+
monkeypatch.setenv("MODEL_IMPL_TYPE", "FLAX_NNX")
|
|
90
|
+
assert envs.MODEL_IMPL_TYPE == "flax_nnx"
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def test_string_env_vars_defaults(monkeypatch: pytest.MonkeyPatch):
|
|
94
|
+
monkeypatch.delenv("JAX_PLATFORMS", raising=False)
|
|
95
|
+
monkeypatch.delenv("PREFILL_SLICES", raising=False)
|
|
96
|
+
monkeypatch.delenv("DECODE_SLICES", raising=False)
|
|
97
|
+
|
|
98
|
+
assert envs.JAX_PLATFORMS == ""
|
|
99
|
+
assert envs.PREFILL_SLICES == ""
|
|
100
|
+
assert envs.DECODE_SLICES == ""
|
|
101
|
+
assert envs.PHASED_PROFILING_DIR == ""
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def test_none_default_env_vars(monkeypatch: pytest.MonkeyPatch):
|
|
105
|
+
monkeypatch.delenv("TPU_ACCELERATOR_TYPE", raising=False)
|
|
106
|
+
monkeypatch.delenv("TPU_NAME", raising=False)
|
|
107
|
+
monkeypatch.delenv("TPU_WORKER_ID", raising=False)
|
|
108
|
+
|
|
109
|
+
assert envs.TPU_ACCELERATOR_TYPE is None
|
|
110
|
+
assert envs.TPU_NAME is None
|
|
111
|
+
assert envs.TPU_WORKER_ID is None
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def test_ray_env_vars(monkeypatch: pytest.MonkeyPatch):
|
|
115
|
+
assert envs.RAY_USAGE_STATS_ENABLED == "0"
|
|
116
|
+
monkeypatch.setenv("RAY_USAGE_STATS_ENABLED", "1")
|
|
117
|
+
assert envs.RAY_USAGE_STATS_ENABLED == "1"
|
|
118
|
+
|
|
119
|
+
assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "shm"
|
|
120
|
+
monkeypatch.setenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "nccl")
|
|
121
|
+
assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def test_invalid_attribute_raises_error():
|
|
125
|
+
with pytest.raises(AttributeError,
|
|
126
|
+
match="has no attribute 'NONEXISTENT_VAR'"):
|
|
127
|
+
_ = envs.NONEXISTENT_VAR
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def test_dir_returns_all_env_vars():
|
|
131
|
+
env_vars = envs.__dir__()
|
|
132
|
+
assert isinstance(env_vars, list)
|
|
133
|
+
assert len(env_vars) == len(environment_variables)
|
|
134
|
+
assert "JAX_PLATFORMS" in env_vars
|
|
135
|
+
assert "TPU_NAME" in env_vars
|
|
136
|
+
assert "SKIP_JAX_PRECOMPILE" in env_vars
|
|
137
|
+
assert "MODEL_IMPL_TYPE" in env_vars
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def test_tpu_multihost_env_vars(monkeypatch: pytest.MonkeyPatch):
|
|
141
|
+
monkeypatch.setenv("TPU_WORKER_ID", "0")
|
|
142
|
+
assert envs.TPU_WORKER_ID == "0"
|
|
143
|
+
|
|
144
|
+
monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "grpc")
|
|
145
|
+
assert envs.TPU_MULTIHOST_BACKEND == "grpc"
|
|
146
|
+
|
|
147
|
+
monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "xla")
|
|
148
|
+
assert envs.TPU_MULTIHOST_BACKEND == "xla"
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
|
|
152
|
+
monkeypatch.setenv("PREFILL_SLICES", "0,1,2,3")
|
|
153
|
+
assert envs.PREFILL_SLICES == "0,1,2,3"
|
|
154
|
+
|
|
155
|
+
monkeypatch.setenv("DECODE_SLICES", "4,5,6,7")
|
|
156
|
+
assert envs.DECODE_SLICES == "4,5,6,7"
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def test_model_impl_type_default(monkeypatch: pytest.MonkeyPatch):
|
|
160
|
+
monkeypatch.delenv("MODEL_IMPL_TYPE", raising=False)
|
|
161
|
+
assert envs.MODEL_IMPL_TYPE == "flax_nnx"
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def test_cache_preserves_values_across_env_changes(
|
|
165
|
+
monkeypatch: pytest.MonkeyPatch):
|
|
166
|
+
monkeypatch.setenv("JAX_PLATFORMS", "tpu")
|
|
167
|
+
|
|
168
|
+
enable_envs_cache()
|
|
169
|
+
|
|
170
|
+
assert envs.JAX_PLATFORMS == "tpu"
|
|
171
|
+
|
|
172
|
+
# Change environment variable
|
|
173
|
+
monkeypatch.setenv("JAX_PLATFORMS", "cpu")
|
|
174
|
+
|
|
175
|
+
# Cached value should still be "tpu"
|
|
176
|
+
assert envs.JAX_PLATFORMS == "tpu"
|
|
177
|
+
|
|
178
|
+
# Reset envs.__getattr__ back to non-cached version
|
|
179
|
+
envs.__getattr__ = envs.__getattr__.__wrapped__
|
|
180
|
+
|
|
181
|
+
# Now it should reflect the new value
|
|
182
|
+
assert envs.JAX_PLATFORMS == "cpu"
|
tests/test_utils.py
CHANGED
|
@@ -75,25 +75,34 @@ def test_hbm_usage_bytes_pathways_enabled(mock_devices, mock_live_arrays):
|
|
|
75
75
|
mock_device2 = MagicMock()
|
|
76
76
|
devices = [mock_device1, mock_device2]
|
|
77
77
|
|
|
78
|
-
# Create mock
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
78
|
+
# Create mock addressable shards with data property
|
|
79
|
+
mock_data1_dev1 = MagicMock()
|
|
80
|
+
mock_data1_dev1.device = mock_device1
|
|
81
|
+
mock_data1_dev1.nbytes = 2000 # 2000 bytes on device1
|
|
82
82
|
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
83
|
+
mock_data1_dev2 = MagicMock()
|
|
84
|
+
mock_data1_dev2.device = mock_device2
|
|
85
|
+
mock_data1_dev2.nbytes = 2000 # 2000 bytes on device2
|
|
86
86
|
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
87
|
+
mock_data2_dev1 = MagicMock()
|
|
88
|
+
mock_data2_dev1.device = mock_device1
|
|
89
|
+
mock_data2_dev1.nbytes = 1000 # 1000 bytes on device1
|
|
90
90
|
|
|
91
|
-
|
|
91
|
+
mock_shard1_dev1 = MagicMock()
|
|
92
|
+
mock_shard1_dev1.data = mock_data1_dev1
|
|
93
|
+
|
|
94
|
+
mock_shard1_dev2 = MagicMock()
|
|
95
|
+
mock_shard1_dev2.data = mock_data1_dev2
|
|
96
|
+
|
|
97
|
+
mock_shard2_dev1 = MagicMock()
|
|
98
|
+
mock_shard2_dev1.data = mock_data2_dev1
|
|
99
|
+
|
|
100
|
+
# Create mock arrays with addressable_shards
|
|
92
101
|
mock_array1 = MagicMock()
|
|
93
|
-
mock_array1.
|
|
102
|
+
mock_array1.addressable_shards = [mock_shard1_dev1, mock_shard1_dev2]
|
|
94
103
|
|
|
95
104
|
mock_array2 = MagicMock()
|
|
96
|
-
mock_array2.
|
|
105
|
+
mock_array2.addressable_shards = [mock_shard2_dev1]
|
|
97
106
|
|
|
98
107
|
mock_live_arrays.return_value = [mock_array1, mock_array2]
|
|
99
108
|
|
|
@@ -159,7 +168,7 @@ def test_hbm_usage_bytes_pathways_no_arrays(mock_devices, mock_live_arrays):
|
|
|
159
168
|
"head_dim, expected_padded_head_dim",
|
|
160
169
|
[
|
|
161
170
|
(1, 128),
|
|
162
|
-
(64,
|
|
171
|
+
(64, 64),
|
|
163
172
|
(127, 128),
|
|
164
173
|
(128, 128),
|
|
165
174
|
(129, 256),
|
tpu_inference/__init__.py
CHANGED
|
@@ -1,21 +1,40 @@
|
|
|
1
|
-
import os
|
|
2
|
-
|
|
3
1
|
# The environment variables override should be imported before any other
|
|
4
2
|
# modules to ensure that the environment variables are set before any
|
|
5
3
|
# other modules are imported.
|
|
6
4
|
import tpu_inference.env_override # noqa: F401
|
|
5
|
+
from tpu_inference import envs
|
|
7
6
|
from tpu_inference import tpu_info as ti
|
|
8
7
|
from tpu_inference.logger import init_logger
|
|
9
8
|
|
|
10
9
|
logger = init_logger(__name__)
|
|
11
10
|
|
|
12
|
-
if "proxy" in
|
|
11
|
+
if "proxy" in envs.JAX_PLATFORMS:
|
|
13
12
|
logger.info("Running vLLM on TPU via Pathways proxy.")
|
|
14
13
|
# Must run pathwaysutils.initialize() before any JAX operations
|
|
15
14
|
try:
|
|
15
|
+
import traceback
|
|
16
|
+
|
|
16
17
|
import pathwaysutils
|
|
18
|
+
import vllm
|
|
19
|
+
from vllm.platforms import (resolve_current_platform_cls_qualname,
|
|
20
|
+
resolve_obj_by_qualname)
|
|
17
21
|
pathwaysutils.initialize()
|
|
18
22
|
logger.info("Module pathwaysutils is imported.")
|
|
23
|
+
|
|
24
|
+
# Pathways requires eager resolution of vllm.current_platform instead of
|
|
25
|
+
# lazy resolution in the normal code path. Since this part involves
|
|
26
|
+
# global topology discovery across multiple hosts, the platform
|
|
27
|
+
# resolution must happen before other components are loaded.
|
|
28
|
+
logger.info("Eagerly resolving vLLM current_platform for Pathways.")
|
|
29
|
+
platform_cls_qualname = resolve_current_platform_cls_qualname()
|
|
30
|
+
resolved_platform_instance = resolve_obj_by_qualname(
|
|
31
|
+
platform_cls_qualname)()
|
|
32
|
+
vllm.platforms._current_platform = resolved_platform_instance
|
|
33
|
+
vllm.platforms._init_trace = "".join(traceback.format_stack())
|
|
34
|
+
logger.info(
|
|
35
|
+
f"vLLM platform resolved to: {resolved_platform_instance.__class__.__name__}"
|
|
36
|
+
)
|
|
37
|
+
|
|
19
38
|
except Exception as e:
|
|
20
39
|
logger.error(
|
|
21
40
|
f"Error occurred while importing pathwaysutils or logging TPU info: {e}"
|
tpu_inference/core/core_tpu.py
CHANGED
|
@@ -29,6 +29,7 @@ from vllm.v1.request import Request, RequestStatus
|
|
|
29
29
|
|
|
30
30
|
from tpu_inference import utils as common_utils
|
|
31
31
|
from tpu_inference.core import disagg_executor, disagg_utils
|
|
32
|
+
from tpu_inference.runner.tpu_runner import AsyncTPUModelRunnerOutput
|
|
32
33
|
# ======================================================================================
|
|
33
34
|
# Imports for _DisaggOrchestrator (decoupled from vLLM)
|
|
34
35
|
# ======================================================================================
|
|
@@ -186,6 +187,8 @@ class _DisaggOrchestrator:
|
|
|
186
187
|
if model_output is None:
|
|
187
188
|
model_output = prefill_engine.model_executor.sample_tokens(
|
|
188
189
|
grammar_output)
|
|
190
|
+
if isinstance(model_output, AsyncTPUModelRunnerOutput):
|
|
191
|
+
model_output = model_output.get_output()
|
|
189
192
|
|
|
190
193
|
if scheduler_output.total_num_scheduled_tokens > 0:
|
|
191
194
|
logger.debug(f"Prefill result: {model_output}")
|
|
@@ -218,15 +221,16 @@ class _DisaggOrchestrator:
|
|
|
218
221
|
f"request-{req_id}: tokens={request.all_token_ids} after prefill"
|
|
219
222
|
)
|
|
220
223
|
# Remove request from the prefill engine.
|
|
224
|
+
if req_id in prefill_engine.scheduler.requests:
|
|
225
|
+
request = prefill_engine.scheduler.requests[req_id]
|
|
226
|
+
prefill_engine.scheduler.running.remove(request)
|
|
227
|
+
prefill_engine.scheduler.encoder_cache_manager.free(
|
|
228
|
+
request)
|
|
221
229
|
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
prefill_engine.scheduler.encoder_cache_manager.free(
|
|
225
|
-
request)
|
|
230
|
+
prefill_engine.scheduler.kv_cache_manager.free(
|
|
231
|
+
request)
|
|
226
232
|
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
prefill_engine.scheduler.requests.pop(req_id)
|
|
233
|
+
prefill_engine.scheduler.requests.pop(req_id)
|
|
230
234
|
|
|
231
235
|
for output in (engine_core_outputs.items()
|
|
232
236
|
if engine_core_outputs else ()):
|
|
@@ -335,8 +339,10 @@ class _DisaggOrchestrator:
|
|
|
335
339
|
new_block_ids = kv_cache_manager.get_block_ids(req_id)
|
|
336
340
|
logger.debug(
|
|
337
341
|
f"inserting {req_id} new_block_ids {new_block_ids}")
|
|
338
|
-
|
|
339
|
-
|
|
342
|
+
if len(new_block_ids[0]) != math.ceil(
|
|
343
|
+
prompt_tokens / self._config.cache_config.block_size):
|
|
344
|
+
logger.warning("Running out of blocks in decode engine! ")
|
|
345
|
+
break
|
|
340
346
|
|
|
341
347
|
decode_engine.model_executor.driver_worker.model_runner.insert_request_with_kv_cache(
|
|
342
348
|
vllm_request, kv_cache, new_block_ids)
|
|
@@ -366,6 +372,8 @@ class _DisaggOrchestrator:
|
|
|
366
372
|
if model_output is None:
|
|
367
373
|
model_output = decode_engine.model_executor.sample_tokens(
|
|
368
374
|
grammar_output)
|
|
375
|
+
if isinstance(model_output, AsyncTPUModelRunnerOutput):
|
|
376
|
+
model_output = model_output.get_output()
|
|
369
377
|
|
|
370
378
|
if scheduler_output.total_num_scheduled_tokens > 0:
|
|
371
379
|
logger.debug(f"Decode result: {model_output}")
|
|
@@ -1,17 +1,15 @@
|
|
|
1
1
|
# SPDX-License-Identifier: Apache-2.0
|
|
2
2
|
|
|
3
|
-
import os
|
|
4
3
|
from typing import Tuple
|
|
5
4
|
|
|
6
|
-
|
|
7
|
-
DECODE_SLICES = 'DECODE_SLICES'
|
|
5
|
+
from tpu_inference import envs
|
|
8
6
|
|
|
9
7
|
|
|
10
8
|
def is_disagg_enabled() -> bool:
|
|
11
9
|
# We triggrer our code path as long as prefill slices are set. This
|
|
12
10
|
# allows us to test interleave mode effectively with the code path
|
|
13
11
|
# for comparison purposes.
|
|
14
|
-
return PREFILL_SLICES
|
|
12
|
+
return bool(envs.PREFILL_SLICES)
|
|
15
13
|
|
|
16
14
|
|
|
17
15
|
def _parse_slices(slices_str: str) -> Tuple[int, ...]:
|
|
@@ -40,12 +38,12 @@ def _parse_slices(slices_str: str) -> Tuple[int, ...]:
|
|
|
40
38
|
|
|
41
39
|
|
|
42
40
|
def get_prefill_slices() -> Tuple[int, ...]:
|
|
43
|
-
if
|
|
41
|
+
if not envs.PREFILL_SLICES:
|
|
44
42
|
return ()
|
|
45
|
-
return _parse_slices(
|
|
43
|
+
return _parse_slices(envs.PREFILL_SLICES)
|
|
46
44
|
|
|
47
45
|
|
|
48
46
|
def get_decode_slices() -> Tuple[int, ...]:
|
|
49
|
-
if
|
|
47
|
+
if not envs.DECODE_SLICES:
|
|
50
48
|
return ()
|
|
51
|
-
return _parse_slices(
|
|
49
|
+
return _parse_slices(envs.DECODE_SLICES)
|
|
@@ -60,7 +60,6 @@ D workflow:
|
|
|
60
60
|
|
|
61
61
|
import copy
|
|
62
62
|
import functools
|
|
63
|
-
import os
|
|
64
63
|
import threading
|
|
65
64
|
import time
|
|
66
65
|
from concurrent.futures import Future, ThreadPoolExecutor
|
|
@@ -86,6 +85,7 @@ if TYPE_CHECKING:
|
|
|
86
85
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
|
87
86
|
from vllm.v1.request import Request
|
|
88
87
|
|
|
88
|
+
from tpu_inference import envs
|
|
89
89
|
from tpu_inference.distributed.utils import (get_host_ip, get_kv_ips,
|
|
90
90
|
get_kv_ports,
|
|
91
91
|
get_kv_transfer_port, get_node_id,
|
|
@@ -441,8 +441,7 @@ class TPUConnectorWorker:
|
|
|
441
441
|
|
|
442
442
|
self.runner: TPUModelRunner = None
|
|
443
443
|
self.mesh: Mesh = None
|
|
444
|
-
self.multi_host =
|
|
445
|
-
"").lower() == "ray"
|
|
444
|
+
self.multi_host = envs.TPU_MULTIHOST_BACKEND == "ray"
|
|
446
445
|
# NOTE(xiang): This can not be the worker rank set in RayDistributedExecutor.
|
|
447
446
|
# The worker rank is assigned with vLLM's sorting logic, which does not work
|
|
448
447
|
# for TPU host topology.
|
|
@@ -2,6 +2,7 @@ import os
|
|
|
2
2
|
|
|
3
3
|
from vllm.utils.network_utils import get_ip
|
|
4
4
|
|
|
5
|
+
from tpu_inference import envs
|
|
5
6
|
from tpu_inference.logger import init_logger
|
|
6
7
|
|
|
7
8
|
logger = init_logger(__name__)
|
|
@@ -17,7 +18,7 @@ def set_node_kv_ip_port(ip_port: tuple[int, str, int]):
|
|
|
17
18
|
|
|
18
19
|
|
|
19
20
|
def get_kv_ips() -> str:
|
|
20
|
-
if
|
|
21
|
+
if envs.TPU_MULTIHOST_BACKEND == "ray":
|
|
21
22
|
num_nodes = len(_NODES_KV_IP_PORT)
|
|
22
23
|
ips = []
|
|
23
24
|
for node_id in range(num_nodes):
|
|
@@ -28,7 +29,7 @@ def get_kv_ips() -> str:
|
|
|
28
29
|
|
|
29
30
|
|
|
30
31
|
def get_kv_ports() -> str:
|
|
31
|
-
if
|
|
32
|
+
if envs.TPU_MULTIHOST_BACKEND == "ray":
|
|
32
33
|
num_nodes = len(_NODES_KV_IP_PORT)
|
|
33
34
|
ports = []
|
|
34
35
|
for node_id in range(num_nodes):
|
tpu_inference/envs.py
CHANGED
|
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
|
|
|
26
26
|
environment_variables: dict[str, Callable[[], Any]] = {
|
|
27
27
|
# JAX platform selection (e.g., "tpu", "cpu", "proxy")
|
|
28
28
|
"JAX_PLATFORMS":
|
|
29
|
-
lambda: os.getenv("JAX_PLATFORMS", ""),
|
|
29
|
+
lambda: os.getenv("JAX_PLATFORMS", "").lower(),
|
|
30
30
|
# TPU accelerator type (e.g., "v5litepod-16", "v4-8")
|
|
31
31
|
"TPU_ACCELERATOR_TYPE":
|
|
32
32
|
lambda: os.getenv("TPU_ACCELERATOR_TYPE", None),
|
|
@@ -108,6 +108,9 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
|
|
|
108
108
|
ip_port = self.collective_rpc("get_node_kv_ip_port")
|
|
109
109
|
for item in ip_port:
|
|
110
110
|
set_node_kv_ip_port(item)
|
|
111
|
+
self.uses_sampler = self.vllm_config.model_config.runner_type != "pooling" and (
|
|
112
|
+
self.vllm_config.ec_transfer_config is None
|
|
113
|
+
or not self.vllm_config.ec_transfer_config.is_ec_producer)
|
|
111
114
|
|
|
112
115
|
def _initialize_ray_cluster(self) -> None:
|
|
113
116
|
"""Initialize the distributed cluster with Ray.
|
|
@@ -131,10 +134,17 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
|
|
|
131
134
|
f"current platform {current_platform.device_name} does not "
|
|
132
135
|
"support ray.")
|
|
133
136
|
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
137
|
+
pp_size = self.parallel_config.pipeline_parallel_size
|
|
138
|
+
placement_group_specs: List[Dict[str, float]] = []
|
|
139
|
+
if pp_size == 1:
|
|
140
|
+
placement_group_specs = [{
|
|
141
|
+
device_str: node['Resources'][device_str]
|
|
142
|
+
} for node in ray.nodes()]
|
|
143
|
+
else:
|
|
144
|
+
num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
|
|
145
|
+
placement_group_specs = [{
|
|
146
|
+
device_str: num_devices_per_pp_rank
|
|
147
|
+
} for _ in range(pp_size)]
|
|
138
148
|
|
|
139
149
|
# vLLM engine is also a worker to execute model with an accelerator,
|
|
140
150
|
# so it requires to have the device in a current node. Check if
|
|
@@ -329,6 +339,8 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
|
|
|
329
339
|
all_kwargs = []
|
|
330
340
|
for rank, (node_id, _) in enumerate(worker_node_and_tpu_ids):
|
|
331
341
|
local_rank = node_workers[node_id].index(rank)
|
|
342
|
+
ip = sorted_worker_metadata[rank].ip
|
|
343
|
+
prev_ip = sorted_worker_metadata[rank - 1].ip if rank > 0 else ""
|
|
332
344
|
kwargs = dict(
|
|
333
345
|
vllm_config=self.vllm_config,
|
|
334
346
|
local_rank=local_rank,
|
|
@@ -336,22 +348,26 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
|
|
|
336
348
|
distributed_init_method=distributed_init_method,
|
|
337
349
|
is_driver_worker=(not self.parallel_config)
|
|
338
350
|
or (rank % self.parallel_config.tensor_parallel_size == 0),
|
|
351
|
+
ip=ip,
|
|
352
|
+
prev_worker_ip=prev_ip,
|
|
339
353
|
)
|
|
340
354
|
all_kwargs.append(kwargs)
|
|
341
355
|
self.collective_rpc("init_worker", args=(all_kwargs, ))
|
|
342
356
|
self.collective_rpc("init_device")
|
|
357
|
+
if self.parallel_config.pipeline_parallel_size > 1:
|
|
358
|
+
self.collective_rpc("initialize_pp_transfer_connect")
|
|
343
359
|
self.collective_rpc("load_model")
|
|
344
360
|
|
|
345
361
|
if self.use_ray_spmd_worker:
|
|
346
362
|
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
|
|
347
363
|
self.pp_tp_workers.append([])
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
#
|
|
353
|
-
|
|
354
|
-
|
|
364
|
+
num_tp_workers = int(
|
|
365
|
+
self.parallel_config.tensor_parallel_size //
|
|
366
|
+
num_tpu_per_worker)
|
|
367
|
+
for tp_rank in range(num_tp_workers):
|
|
368
|
+
# PP=2, TP=4, num_tpu_per_worker=2
|
|
369
|
+
# pp_tp_workers = [[0, 1], [2, 3]]
|
|
370
|
+
rank = (pp_rank * num_tp_workers) + tp_rank
|
|
355
371
|
assert len(self.pp_tp_workers[pp_rank]) == tp_rank
|
|
356
372
|
assert pp_rank < len(self.pp_tp_workers)
|
|
357
373
|
self.pp_tp_workers[pp_rank].append(self.workers[rank])
|