tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/kernels/mla_v1_test.py +129 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
- tests/lora/test_layers.py +4 -1
- tests/lora/test_lora_perf.py +53 -0
- tests/test_envs.py +110 -12
- tests/test_quantization.py +3 -0
- tests/test_utils.py +1 -2
- tpu_inference/distributed/tpu_connector.py +1 -1
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/ray_distributed_executor.py +5 -1
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
- tpu_inference/kernels/mla/v1/kernel.py +98 -120
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +82 -32
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +146 -85
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +11 -7
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +170 -208
- tpu_inference/layers/vllm/linear_common.py +43 -21
- tpu_inference/layers/vllm/quantization/common.py +11 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
- tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
- tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
- tpu_inference/models/common/model_loader.py +78 -22
- tpu_inference/models/jax/deepseek_v3.py +185 -64
- tpu_inference/models/jax/gpt_oss.py +3 -3
- tpu_inference/models/jax/llama_eagle3.py +4 -5
- tpu_inference/models/jax/qwen2_5_vl.py +161 -47
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
- tpu_inference/models/jax/utils/weight_utils.py +203 -155
- tpu_inference/models/vllm/vllm_model_wrapper.py +11 -5
- tpu_inference/platforms/tpu_platform.py +29 -48
- tpu_inference/runner/compilation_manager.py +112 -46
- tpu_inference/runner/kv_cache.py +40 -20
- tpu_inference/runner/kv_cache_manager.py +40 -31
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +94 -51
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -22
- tpu_inference/utils.py +41 -14
- tpu_inference/worker/tpu_worker.py +43 -45
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +8 -9
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +59 -58
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tests/lora/test_layers.py
CHANGED
|
@@ -18,7 +18,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
|
|
|
18
18
|
ReplicatedLinearWithLoRA,
|
|
19
19
|
RowParallelLinearWithLoRA)
|
|
20
20
|
# yapf: enable
|
|
21
|
-
from vllm.lora.
|
|
21
|
+
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
|
|
22
22
|
from vllm.lora.punica_wrapper import get_punica_wrapper
|
|
23
23
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|
24
24
|
MergedColumnParallelLinear,
|
|
@@ -205,6 +205,9 @@ def create_random_inputs(
|
|
|
205
205
|
@pytest.mark.parametrize("repeats", [1, 2, 3])
|
|
206
206
|
@pytest.mark.parametrize("stage", [True, False])
|
|
207
207
|
def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
|
|
208
|
+
# TODO(Qiliang Cui): Remove when issue is resolved.
|
|
209
|
+
if 'TPU7x' in jax.devices()[0].device_kind:
|
|
210
|
+
pytest.skip("Skipping test on TPU TPU7x.")
|
|
208
211
|
set_random_seed(6)
|
|
209
212
|
|
|
210
213
|
max_loras = 9
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import time
|
|
3
|
+
|
|
4
|
+
import pytest
|
|
5
|
+
import vllm
|
|
6
|
+
from vllm.lora.request import LoRARequest
|
|
7
|
+
|
|
8
|
+
TP = [2] if os.environ.get("USE_V6E8_QUEUE", False) else [1]


@pytest.mark.parametrize("tp", TP)
def test_lora_performance(tp):
    """Check that serving through a LoRA adapter is not drastically slower.

    The same greedy 16-token request is timed twice: once on a plain engine
    and once on an engine with LoRA enabled, routed through a LoRA adapter.
    Both engines use identical settings apart from the LoRA options. The test
    fails when the LoRA run takes 8x (or more) the time of the base run.
    """
    prompt = "What is 1+1? \n"
    # Settings shared by both engines so that only the LoRA options differ.
    engine_kwargs = dict(
        model="Qwen/Qwen2.5-3B-Instruct",
        max_model_len=256,
        max_num_batched_tokens=64,
        max_num_seqs=8,
        tensor_parallel_size=tp,
    )

    llm_without_lora = vllm.LLM(**engine_kwargs)
    # perf_counter is monotonic, unlike time.time(), so it suits durations.
    start_time = time.perf_counter()
    llm_without_lora.generate(
        prompt,
        sampling_params=vllm.SamplingParams(max_tokens=16, temperature=0),
    )[0].outputs[0].text
    base_time = time.perf_counter() - start_time

    del llm_without_lora
    # Waiting for TPUs to be released
    time.sleep(10)

    llm_with_lora = vllm.LLM(**engine_kwargs,
                             enable_lora=True,
                             max_loras=1,
                             max_lora_rank=8)
    lora_request = LoRARequest(
        "lora_adapter_2", 2,
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_2_adapter")
    start_time = time.perf_counter()
    llm_with_lora.generate(prompt,
                           sampling_params=vllm.SamplingParams(max_tokens=16,
                                                               temperature=0),
                           lora_request=lora_request)[0].outputs[0].text
    lora_time = time.perf_counter() - start_time
    # Drop the engine before asserting, so a failing assert does not keep it
    # alive through the traceback's references to this frame.
    del llm_with_lora

    print(f"Base time: {base_time}, LoRA time: {lora_time}")
    # LoRA is expected to be slower than the base model; bound the slowdown.
    # (The former `base_time / lora_time < 8` could only fail if LoRA were
    # 8x *faster*, so it never checked anything useful.)
    assert (lora_time /
            base_time) < 8, f"Base time: {base_time}, LoRA time: {lora_time}"
|
tests/test_envs.py
CHANGED
|
@@ -56,6 +56,13 @@ def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch):
|
|
|
56
56
|
|
|
57
57
|
|
|
58
58
|
def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
|
|
59
|
+
# Ensure clean environment for boolean vars by setting to default "0"
|
|
60
|
+
monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
|
|
61
|
+
monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
|
|
62
|
+
monkeypatch.setenv("NEW_MODEL_DESIGN", "0")
|
|
63
|
+
monkeypatch.setenv("ENABLE_QUANTIZED_MATMUL_KERNEL", "0")
|
|
64
|
+
monkeypatch.setenv("USE_MOE_EP_KERNEL", "0")
|
|
65
|
+
|
|
59
66
|
# Test SKIP_JAX_PRECOMPILE (default False)
|
|
60
67
|
assert envs.SKIP_JAX_PRECOMPILE is False
|
|
61
68
|
monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "1")
|
|
@@ -63,6 +70,13 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
|
|
|
63
70
|
monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
|
|
64
71
|
assert envs.SKIP_JAX_PRECOMPILE is False
|
|
65
72
|
|
|
73
|
+
# Test VLLM_XLA_CHECK_RECOMPILATION (default False)
|
|
74
|
+
assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
|
|
75
|
+
monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "1")
|
|
76
|
+
assert envs.VLLM_XLA_CHECK_RECOMPILATION is True
|
|
77
|
+
monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
|
|
78
|
+
assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
|
|
79
|
+
|
|
66
80
|
# Test NEW_MODEL_DESIGN (default False)
|
|
67
81
|
assert envs.NEW_MODEL_DESIGN is False
|
|
68
82
|
monkeypatch.setenv("NEW_MODEL_DESIGN", "1")
|
|
@@ -73,22 +87,110 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
|
|
|
73
87
|
monkeypatch.setenv("USE_MOE_EP_KERNEL", "1")
|
|
74
88
|
assert envs.USE_MOE_EP_KERNEL is True
|
|
75
89
|
|
|
90
|
+
# Test ENABLE_QUANTIZED_MATMUL_KERNEL (default False)
|
|
91
|
+
assert envs.ENABLE_QUANTIZED_MATMUL_KERNEL is False
|
|
92
|
+
monkeypatch.setenv("ENABLE_QUANTIZED_MATMUL_KERNEL", "1")
|
|
93
|
+
assert envs.ENABLE_QUANTIZED_MATMUL_KERNEL is True
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def test_boolean_env_vars_string_values(monkeypatch: pytest.MonkeyPatch):
    """Boolean env vars must also accept textual spellings such as 'True'/'false'."""
    # (variable name, raw string to export, expected parsed value), checked
    # in this exact order.
    cases = [
        ("NEW_MODEL_DESIGN", "True", True),
        ("NEW_MODEL_DESIGN", "true", True),
        ("NEW_MODEL_DESIGN", "False", False),
        ("NEW_MODEL_DESIGN", "false", False),
        ("SKIP_JAX_PRECOMPILE", "True", True),
        ("SKIP_JAX_PRECOMPILE", "false", False),
        ("VLLM_XLA_CHECK_RECOMPILATION", "TRUE", True),
        ("VLLM_XLA_CHECK_RECOMPILATION", "FALSE", False),
        ("USE_MOE_EP_KERNEL", "true", True),
        ("USE_MOE_EP_KERNEL", "False", False),
    ]
    for var_name, raw, expected in cases:
        monkeypatch.setenv(var_name, raw)
        assert getattr(envs, var_name) is expected
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def test_boolean_env_vars_invalid_values(monkeypatch: pytest.MonkeyPatch):
    """Unrecognised boolean spellings must raise a ValueError naming the variable."""
    bad_settings = (
        ("NEW_MODEL_DESIGN", "yes"),
        ("NEW_MODEL_DESIGN", "2"),
        ("SKIP_JAX_PRECOMPILE", "invalid"),
    )
    for var_name, raw in bad_settings:
        monkeypatch.setenv(var_name, raw)
        expected_msg = f"Invalid boolean value '{raw}' for {var_name}"
        with pytest.raises(ValueError, match=expected_msg):
            getattr(envs, var_name)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def test_boolean_env_vars_empty_string(monkeypatch: pytest.MonkeyPatch):
    """An empty boolean env var is treated as unset and yields the default (False)."""
    for var_name in ("NEW_MODEL_DESIGN", "SKIP_JAX_PRECOMPILE"):
        monkeypatch.setenv(var_name, "")
        assert getattr(envs, var_name) is False  # falls back to the default
|
|
165
|
+
|
|
76
166
|
|
|
77
167
|
def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch):
    # Pin both integer vars to their default of "1" so the checks start clean.
    for var_name in ("PYTHON_TRACER_LEVEL", "NUM_SLICES"):
        monkeypatch.setenv(var_name, "1")

    # PYTHON_TRACER_LEVEL: default, then overridden values.
    assert envs.PYTHON_TRACER_LEVEL == 1
    for raw in ("3", "0"):
        monkeypatch.setenv("PYTHON_TRACER_LEVEL", raw)
        assert envs.PYTHON_TRACER_LEVEL == int(raw)

    # NUM_SLICES: default, then overridden values.
    assert envs.NUM_SLICES == 1
    for raw in ("2", "4"):
        monkeypatch.setenv("NUM_SLICES", raw)
        assert envs.NUM_SLICES == int(raw)
|
|
84
184
|
|
|
85
|
-
def test_lowercase_conversion(monkeypatch: pytest.MonkeyPatch):
|
|
86
|
-
monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "GRPC")
|
|
87
|
-
assert envs.TPU_MULTIHOST_BACKEND == "grpc"
|
|
88
185
|
|
|
89
|
-
|
|
186
|
+
def test_model_impl_type_choices(monkeypatch: pytest.MonkeyPatch):
    # Each allowed choice, spelled exactly as declared, is returned unchanged.
    for choice in ("flax_nnx", "vllm"):
        monkeypatch.setenv("MODEL_IMPL_TYPE", choice)
        assert envs.MODEL_IMPL_TYPE == choice
|
|
193
|
+
|
|
92
194
|
|
|
93
195
|
def test_string_env_vars_defaults(monkeypatch: pytest.MonkeyPatch):
|
|
94
196
|
monkeypatch.delenv("JAX_PLATFORMS", raising=False)
|
|
@@ -117,8 +219,6 @@ def test_ray_env_vars(monkeypatch: pytest.MonkeyPatch):
|
|
|
117
219
|
assert envs.RAY_USAGE_STATS_ENABLED == "1"
|
|
118
220
|
|
|
119
221
|
assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "shm"
|
|
120
|
-
monkeypatch.setenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "nccl")
|
|
121
|
-
assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"
|
|
122
222
|
|
|
123
223
|
|
|
124
224
|
def test_invalid_attribute_raises_error():
|
|
@@ -134,6 +234,7 @@ def test_dir_returns_all_env_vars():
|
|
|
134
234
|
assert "JAX_PLATFORMS" in env_vars
|
|
135
235
|
assert "TPU_NAME" in env_vars
|
|
136
236
|
assert "SKIP_JAX_PRECOMPILE" in env_vars
|
|
237
|
+
assert "VLLM_XLA_CHECK_RECOMPILATION" in env_vars
|
|
137
238
|
assert "MODEL_IMPL_TYPE" in env_vars
|
|
138
239
|
|
|
139
240
|
|
|
@@ -141,11 +242,8 @@ def test_tpu_multihost_env_vars(monkeypatch: pytest.MonkeyPatch):
|
|
|
141
242
|
monkeypatch.setenv("TPU_WORKER_ID", "0")
|
|
142
243
|
assert envs.TPU_WORKER_ID == "0"
|
|
143
244
|
|
|
144
|
-
monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "
|
|
145
|
-
assert envs.TPU_MULTIHOST_BACKEND == "
|
|
146
|
-
|
|
147
|
-
monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "xla")
|
|
148
|
-
assert envs.TPU_MULTIHOST_BACKEND == "xla"
|
|
245
|
+
monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "ray")
|
|
246
|
+
assert envs.TPU_MULTIHOST_BACKEND == "ray"
|
|
149
247
|
|
|
150
248
|
|
|
151
249
|
def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
|
|
@@ -158,7 +256,7 @@ def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
|
|
|
158
256
|
|
|
159
257
|
def test_model_impl_type_default(monkeypatch: pytest.MonkeyPatch):
    """With MODEL_IMPL_TYPE unset, the accessor falls back to 'auto'."""
    monkeypatch.delenv("MODEL_IMPL_TYPE", raising=False)
    expected_default = "auto"
    assert envs.MODEL_IMPL_TYPE == expected_default
|
|
162
260
|
|
|
163
261
|
|
|
164
262
|
def test_cache_preserves_values_across_env_changes(
|
tests/test_quantization.py
CHANGED
|
@@ -112,6 +112,8 @@ class TestQwixQuantizeNnxModel(unittest.TestCase):
|
|
|
112
112
|
self.mesh = Mesh(jax.devices(), ('model', ))
|
|
113
113
|
self.rng = jax.random.PRNGKey(0)
|
|
114
114
|
self.model = SimpleModel(rngs=nnx.Rngs(0))
|
|
115
|
+
self.model.vllm_config = MagicMock()
|
|
116
|
+
self.model.vllm_config.model_config.use_mla = False
|
|
115
117
|
|
|
116
118
|
self.qwix_config = [
|
|
117
119
|
{
|
|
@@ -131,6 +133,7 @@ class TestQwixQuantizeNnxModel(unittest.TestCase):
|
|
|
131
133
|
"""Test that qwix.quantize_model is called with the correct arguments."""
|
|
132
134
|
quantized_model_mock = MagicMock(spec=nnx.Module)
|
|
133
135
|
mock_quantize_model.return_value = quantized_model_mock
|
|
136
|
+
self.model.vllm_config.sharding_config.total_dp_size = 1
|
|
134
137
|
|
|
135
138
|
with patch(
|
|
136
139
|
"tpu_inference.models.jax.utils.quantization.quantization_utils.init_logger",
|
tests/test_utils.py
CHANGED
|
@@ -231,6 +231,5 @@ def test_get_jax_dtype_from_str_dtype():
|
|
|
231
231
|
assert get_jax_dtype_from_str_dtype("int8") == jnp.int8
|
|
232
232
|
assert get_jax_dtype_from_str_dtype("bfloat16") == jnp.bfloat16
|
|
233
233
|
assert get_jax_dtype_from_str_dtype("fp8") == jnp.float8_e4m3fn
|
|
234
|
-
assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.
|
|
234
|
+
assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.float8_e4m3fn
|
|
235
235
|
assert get_jax_dtype_from_str_dtype("fp8_e5m2") == jnp.float8_e5m2
|
|
236
|
-
assert get_jax_dtype_from_str_dtype("auto") is None
|
|
@@ -457,7 +457,6 @@ class TPUConnectorWorker:
|
|
|
457
457
|
self.side_channel_port = get_side_channel_port()
|
|
458
458
|
|
|
459
459
|
self.kv_transfer_server = None
|
|
460
|
-
self._maybe_start_p2p_server()
|
|
461
460
|
self.zmq_cxt = zmq.Context()
|
|
462
461
|
if self.is_producer:
|
|
463
462
|
ready_event = threading.Event()
|
|
@@ -499,6 +498,7 @@ class TPUConnectorWorker:
|
|
|
499
498
|
self.shape = list(kv_layer.shape)
|
|
500
499
|
self.dtype = kv_layer.dtype
|
|
501
500
|
self.sharding = kv_layer.sharding
|
|
501
|
+
self._maybe_start_p2p_server()
|
|
502
502
|
|
|
503
503
|
def _maybe_start_p2p_server(self):
|
|
504
504
|
if self.kv_transfer_server is not None:
|
tpu_inference/envs.py
CHANGED
|
@@ -15,13 +15,88 @@ if TYPE_CHECKING:
|
|
|
15
15
|
PREFILL_SLICES: str = ""
|
|
16
16
|
DECODE_SLICES: str = ""
|
|
17
17
|
SKIP_JAX_PRECOMPILE: bool = False
|
|
18
|
-
|
|
18
|
+
VLLM_XLA_CHECK_RECOMPILATION: bool = False
|
|
19
|
+
MODEL_IMPL_TYPE: str = "auto"
|
|
19
20
|
NEW_MODEL_DESIGN: bool = False
|
|
20
21
|
PHASED_PROFILING_DIR: str = ""
|
|
21
22
|
PYTHON_TRACER_LEVEL: int = 1
|
|
22
23
|
USE_MOE_EP_KERNEL: bool = False
|
|
24
|
+
NUM_SLICES: int = 1
|
|
23
25
|
RAY_USAGE_STATS_ENABLED: str = "0"
|
|
24
26
|
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"
|
|
27
|
+
ENABLE_QUANTIZED_MATMUL_KERNEL: bool = False
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def env_with_choices(
|
|
31
|
+
env_name: str,
|
|
32
|
+
default: str | None,
|
|
33
|
+
choices: list[str] | Callable[[], list[str]],
|
|
34
|
+
case_sensitive: bool = True,
|
|
35
|
+
) -> Callable[[], str | None]:
|
|
36
|
+
"""
|
|
37
|
+
Create a lambda that validates environment variable against allowed choices
|
|
38
|
+
|
|
39
|
+
Args:
|
|
40
|
+
env_name: Name of the environment variable
|
|
41
|
+
default: Default value if not set (can be None)
|
|
42
|
+
choices: List of valid string options or callable that returns list
|
|
43
|
+
case_sensitive: Whether validation should be case sensitive
|
|
44
|
+
|
|
45
|
+
Returns:
|
|
46
|
+
Lambda function for environment_variables dict
|
|
47
|
+
"""
|
|
48
|
+
|
|
49
|
+
def _get_validated_env() -> str | None:
|
|
50
|
+
value = os.getenv(env_name)
|
|
51
|
+
if value is None:
|
|
52
|
+
return default
|
|
53
|
+
|
|
54
|
+
# Resolve choices if it's a callable (for lazy loading)
|
|
55
|
+
actual_choices = choices() if callable(choices) else choices
|
|
56
|
+
|
|
57
|
+
if not case_sensitive:
|
|
58
|
+
check_value = value.lower()
|
|
59
|
+
check_choices = [choice.lower() for choice in actual_choices]
|
|
60
|
+
else:
|
|
61
|
+
check_value = value
|
|
62
|
+
check_choices = actual_choices
|
|
63
|
+
|
|
64
|
+
if check_value not in check_choices:
|
|
65
|
+
raise ValueError(f"Invalid value '{value}' for {env_name}. "
|
|
66
|
+
f"Valid options: {actual_choices}.")
|
|
67
|
+
|
|
68
|
+
return value
|
|
69
|
+
|
|
70
|
+
return _get_validated_env
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def env_bool(env_name: str, default: bool = False) -> Callable[[], bool]:
    """
    Build a zero-argument reader that parses ``env_name`` as a boolean.

    Accepted spellings, compared case-insensitively: "1" / "true" mean True,
    "0" / "false" mean False. An unset or empty variable yields ``default``.

    Args:
        env_name: Name of the environment variable
        default: Boolean returned when the variable is unset or empty

    Returns:
        Callable for the environment_variables dict.

    Raises:
        ValueError: from the returned callable, for any other value.
    """
    truthy = ("true", "1")
    falsy = ("false", "0")

    def _get_bool_env() -> bool:
        raw = os.getenv(env_name)
        # os.getenv yields None or a str, so this covers "unset" and "".
        if not raw:
            return default

        normalized = raw.lower()
        if normalized in truthy:
            return True
        if normalized in falsy:
            return False
        raise ValueError(
            f"Invalid boolean value '{raw}' for {env_name}. "
            f"Valid options: '0', '1', 'true', 'false', 'True', 'False'.")

    return _get_bool_env
|
|
99
|
+
|
|
25
100
|
|
|
26
101
|
environment_variables: dict[str, Callable[[], Any]] = {
|
|
27
102
|
# JAX platform selection (e.g., "tpu", "cpu", "proxy")
|
|
@@ -38,7 +113,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|
|
38
113
|
lambda: os.getenv("TPU_WORKER_ID", None),
|
|
39
114
|
# Backend for multi-host communication on TPU
|
|
40
115
|
"TPU_MULTIHOST_BACKEND":
|
|
41
|
-
|
|
116
|
+
env_with_choices("TPU_MULTIHOST_BACKEND", "", ["ray"]),
|
|
42
117
|
# Slice configuration for disaggregated prefill workers
|
|
43
118
|
"PREFILL_SLICES":
|
|
44
119
|
lambda: os.getenv("PREFILL_SLICES", ""),
|
|
@@ -47,28 +122,37 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|
|
47
122
|
lambda: os.getenv("DECODE_SLICES", ""),
|
|
48
123
|
# Skip JAX precompilation step during initialization
|
|
49
124
|
"SKIP_JAX_PRECOMPILE":
|
|
50
|
-
|
|
125
|
+
env_bool("SKIP_JAX_PRECOMPILE", default=False),
|
|
126
|
+
# Check for XLA recompilation during execution
|
|
127
|
+
"VLLM_XLA_CHECK_RECOMPILATION":
|
|
128
|
+
env_bool("VLLM_XLA_CHECK_RECOMPILATION", default=False),
|
|
51
129
|
# Model implementation type (e.g., "flax_nnx")
|
|
52
130
|
"MODEL_IMPL_TYPE":
|
|
53
|
-
|
|
131
|
+
env_with_choices("MODEL_IMPL_TYPE", "auto",
|
|
132
|
+
["auto", "vllm", "flax_nnx", "jetpack"]),
|
|
54
133
|
# Enable new experimental model design
|
|
55
134
|
"NEW_MODEL_DESIGN":
|
|
56
|
-
|
|
135
|
+
env_bool("NEW_MODEL_DESIGN", default=False),
|
|
57
136
|
# Directory to store phased profiling output
|
|
58
137
|
"PHASED_PROFILING_DIR":
|
|
59
138
|
lambda: os.getenv("PHASED_PROFILING_DIR", ""),
|
|
60
139
|
# Python tracer level for profiling
|
|
61
140
|
"PYTHON_TRACER_LEVEL":
|
|
62
|
-
lambda: int(os.getenv("PYTHON_TRACER_LEVEL"
|
|
141
|
+
lambda: int(os.getenv("PYTHON_TRACER_LEVEL") or "1"),
|
|
63
142
|
# Use custom expert-parallel kernel for MoE (Mixture of Experts)
|
|
64
143
|
"USE_MOE_EP_KERNEL":
|
|
65
|
-
|
|
144
|
+
env_bool("USE_MOE_EP_KERNEL", default=False),
|
|
145
|
+
# Number of TPU slices for multi-slice mesh
|
|
146
|
+
"NUM_SLICES":
|
|
147
|
+
lambda: int(os.getenv("NUM_SLICES") or "1"),
|
|
66
148
|
# Enable/disable Ray usage statistics collection
|
|
67
149
|
"RAY_USAGE_STATS_ENABLED":
|
|
68
150
|
lambda: os.getenv("RAY_USAGE_STATS_ENABLED", "0"),
|
|
69
151
|
# Ray compiled DAG channel type for TPU
|
|
70
152
|
"VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
|
|
71
|
-
|
|
153
|
+
env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm", ["shm"]),
|
|
154
|
+
"ENABLE_QUANTIZED_MATMUL_KERNEL":
|
|
155
|
+
lambda: bool(int(os.getenv("ENABLE_QUANTIZED_MATMUL_KERNEL") or "0")),
|
|
72
156
|
}
|
|
73
157
|
|
|
74
158
|
|
|
@@ -136,10 +136,14 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
|
|
|
136
136
|
|
|
137
137
|
pp_size = self.parallel_config.pipeline_parallel_size
|
|
138
138
|
placement_group_specs: List[Dict[str, float]] = []
|
|
139
|
+
|
|
140
|
+
ray_nodes = ray.nodes()
|
|
141
|
+
logger.info(f"RayDistributedExecutor | ray_nodes={ray_nodes}")
|
|
142
|
+
|
|
139
143
|
if pp_size == 1:
|
|
140
144
|
placement_group_specs = [{
|
|
141
145
|
device_str: node['Resources'][device_str]
|
|
142
|
-
} for node in
|
|
146
|
+
} for node in ray_nodes]
|
|
143
147
|
else:
|
|
144
148
|
num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
|
|
145
149
|
placement_group_specs = [{
|
|
@@ -540,12 +540,16 @@ def get_vmem_estimate_bytes(
|
|
|
540
540
|
"""Returns the total vmem bytes used by the kernel."""
|
|
541
541
|
m_per_device = m // tp_size
|
|
542
542
|
n_per_device = n // tp_size
|
|
543
|
-
y_vmem_bytes = n_per_device * k * dtypes.bit_width(y_dtype)
|
|
543
|
+
y_vmem_bytes = (n_per_device * k * (dtypes.bit_width(y_dtype) if hasattr(
|
|
544
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(y_dtype)) // 8)
|
|
544
545
|
total_bytes = (
|
|
545
|
-
2 * m_per_device * k *
|
|
546
|
-
|
|
546
|
+
2 * m_per_device * k *
|
|
547
|
+
(dtypes.bit_width(x_dtype) if hasattr(dtypes, "bit_width") else
|
|
548
|
+
dtypes.itemsize_bits(x_dtype)) // 8 # x_vmem_scratch_ref
|
|
547
549
|
+ y_vmem_bytes # y_vmem_scratch_ref
|
|
548
|
-
+ 2 * m * bn *
|
|
550
|
+
+ 2 * m * bn *
|
|
551
|
+
(dtypes.bit_width(out_dtype) if hasattr(dtypes, "bit_width") else
|
|
552
|
+
dtypes.itemsize_bits(out_dtype)) // 8 # o_vmem_scratch_ref
|
|
549
553
|
+ acc_bytes # acc_vmem_scratch_ref, jnp.float32
|
|
550
554
|
)
|
|
551
555
|
return total_bytes
|
|
@@ -639,8 +643,10 @@ def all_gather_matmul(
|
|
|
639
643
|
# NOTE(chengjiyao): acc buffer is not used in the grid_k == 1 case.
|
|
640
644
|
if grid_k == 1:
|
|
641
645
|
acc_shape = (8, 128)
|
|
642
|
-
acc_bytes =
|
|
643
|
-
|
|
646
|
+
acc_bytes = (
|
|
647
|
+
acc_shape[0] *
|
|
648
|
+
acc_shape[1] * (dtypes.bit_width(jnp.float32) if hasattr(
|
|
649
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(jnp.float32)) // 8)
|
|
644
650
|
y_vmem_shape = (n_per_device, k) if rhs_transpose else (k, n_per_device)
|
|
645
651
|
estimated_vmem_bytes = get_vmem_estimate_bytes(
|
|
646
652
|
m,
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
# SPDX-License-Identifier: Apache-2.0
|
|
2
2
|
"""All-gather matmul kernel's tuned block sizes."""
|
|
3
3
|
|
|
4
|
+
import re
|
|
5
|
+
|
|
4
6
|
import jax
|
|
5
7
|
|
|
6
8
|
# key:
|
|
@@ -32,8 +34,11 @@ def get_tpu_version() -> int:
|
|
|
32
34
|
return -1
|
|
33
35
|
if kind.endswith(' lite'):
|
|
34
36
|
kind = kind[:-len(' lite')]
|
|
35
|
-
|
|
36
|
-
|
|
37
|
+
|
|
38
|
+
# v6: "TPU v6"
|
|
39
|
+
# v7: "TPU7x"
|
|
40
|
+
assert kind[:3] == 'TPU', kind
|
|
41
|
+
return int(re.search(r'\d+', kind).group())
|
|
37
42
|
|
|
38
43
|
|
|
39
44
|
def get_key(
|