tpu-inference 0.12.0.dev20251222__py3-none-any.whl → 0.12.0.dev20251224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/core/test_dp_scheduler.py +128 -71
- tests/e2e/test_data_parallel.py +176 -280
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_speculative_decoding.py +26 -6
- tests/layers/jax/test_qwix.py +1 -1
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +36 -21
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +36 -21
- tests/layers/vllm/test_mxfp4.py +25 -10
- tests/layers/vllm/test_unquantized.py +61 -31
- tests/layers/vllm/utils.py +19 -4
- tests/models/common/test_model_loader.py +2 -2
- tests/models/jax/test_qwen2_5_vl.py +10 -11
- tests/runner/test_multimodal_manager.py +3 -3
- tests/runner/test_tpu_runner.py +67 -8
- tests/runner/test_tpu_runner_dp.py +66 -0
- tpu_inference/core/sched/dp_scheduler.py +65 -40
- tpu_inference/kernels/mla/v1/kernel.py +7 -26
- tpu_inference/layers/common/sharding.py +8 -3
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +3 -3
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +3 -3
- tpu_inference/layers/jax/attention/llama4_attention.py +3 -4
- tpu_inference/layers/jax/sample/sampling.py +1 -1
- tpu_inference/layers/vllm/fused_moe.py +51 -47
- tpu_inference/layers/vllm/quantization/common.py +14 -13
- tpu_inference/layers/vllm/quantization/mxfp4.py +21 -7
- tpu_inference/layers/vllm/quantization/unquantized.py +19 -7
- tpu_inference/layers/vllm/sharding.py +7 -4
- tpu_inference/models/common/model_loader.py +11 -14
- tpu_inference/models/jax/llama3.py +13 -10
- tpu_inference/models/jax/llama_guard_4.py +1 -1
- tpu_inference/models/jax/qwen2.py +3 -2
- tpu_inference/models/jax/qwen2_5_vl.py +4 -4
- tpu_inference/models/jax/utils/multi_modal_utils.py +4 -4
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +3 -3
- tpu_inference/models/vllm/vllm_model_wrapper.py +5 -2
- tpu_inference/platforms/tpu_platform.py +7 -7
- tpu_inference/runner/compilation_manager.py +43 -33
- tpu_inference/runner/kv_cache_manager.py +1 -2
- tpu_inference/runner/multimodal_manager.py +1 -1
- tpu_inference/runner/tpu_runner.py +12 -9
- tpu_inference/utils.py +31 -30
- tpu_inference/worker/tpu_worker.py +5 -2
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/METADATA +1 -1
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/RECORD +47 -46
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/WHEEL +0 -0
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
3
|
+
import os
|
|
4
|
+
import time
|
|
5
|
+
from dataclasses import asdict
|
|
6
|
+
|
|
7
|
+
import pytest
|
|
8
|
+
from vllm import LLM, EngineArgs, SamplingParams
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@pytest.fixture
def model_name():
    """Provide the checkpoint exercised by the hybrid KV cache tests.

    gemma-3-27b-it is used because it contains both full-attention and
    sliding-window-attention layers.
    """
    checkpoint = "google/gemma-3-27b-it"
    return checkpoint
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@pytest.fixture
def test_prompts():
    """Provide the fixed batch of prompts used by the hybrid KV cache tests.

    The first five entries are open-ended sentence starters and the last five
    are short factual questions. The order is fixed so that mismatches can be
    reported by prompt index.
    """
    sentence_starters = [
        "Hello, my name is",
        "The capital of France is",
        "The colors of the rainbow are",
        "The future of AI is",
        "The president of the United States is",
    ]
    factual_questions = [
        "How many players are on a standard soccer team?",
        "In Greek mythology, who is the god of the sea?",
        "What is the capital of Australia?",
        "What is the largest planet in our solar system?",
        "Who developed the theory of general relativity?",
    ]
    return sentence_starters + factual_questions
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@pytest.fixture
def sampling_params():
    """Provide the sampling settings shared by every test in this module.

    - temperature=0.0: greedy decoding, so two runs are directly comparable.
    - ignore_eos=True: decoding continues past EOS up to max_tokens.
    - logprobs=1: per-token log probabilities are returned for comparison.
    """
    settings = dict(
        temperature=0.0,
        max_tokens=32,
        ignore_eos=True,
        logprobs=1,
    )
    return SamplingParams(**settings)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _run_inference_with_config(
        model_name: str,
        test_prompts: list,
        sampling_params: SamplingParams,
        tensor_parallel_size: int = 4,
        kv_cache_dtype: str = "auto",
        enable_prefix_caching: bool = False,
        disable_hybrid_kv_cache_manager: bool = False) -> list:
    """Create an LLM with the requested engine settings and generate once.

    The engine reference is dropped in a ``finally`` block, so teardown also
    happens when generation raises. It is followed by a fixed 10 second
    pause so the TPUs are released before another engine is created.

    Args:
        model_name: Model identifier passed to ``EngineArgs(model=...)``.
        test_prompts: Prompts handed to ``LLM.generate``.
        sampling_params: Sampling settings handed to ``LLM.generate``.
        tensor_parallel_size: Tensor-parallel degree for the engine.
        kv_cache_dtype: KV cache dtype passed through to ``EngineArgs``.
        enable_prefix_caching: Whether prefix caching is enabled.
        disable_hybrid_kv_cache_manager: ``True`` selects the non-hybrid
            (baseline) KV cache manager.

    Returns:
        Whatever ``LLM.generate`` returns for ``test_prompts``.
    """
    # The engine arguments are built directly and then flattened into
    # keyword arguments for the LLM constructor.
    engine_args = EngineArgs(
        model=model_name,
        max_model_len=64,
        tensor_parallel_size=tensor_parallel_size,
        gpu_memory_utilization=0.95,
        max_num_batched_tokens=256,
        max_num_seqs=16,
        enable_prefix_caching=enable_prefix_caching,
        kv_cache_dtype=kv_cache_dtype,
        disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager,
    )
    llm = LLM(**asdict(engine_args))

    try:
        return llm.generate(test_prompts, sampling_params)
    finally:
        del llm
        # Give the runtime time to release the TPUs before the next engine.
        time.sleep(10)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def test_hybrid_kv_cache(
    model_name: str,
    test_prompts: list,
    sampling_params: SamplingParams,
):
    """Smoke-test the hybrid KV cache manager on a vLLM-implemented gemma model.

    ``MODEL_IMPL_TYPE`` is forced to ``'vllm'`` only while the engine runs.
    Its previous value, or its absence, is restored afterwards so the override
    does not leak into tests that run later in the same process.
    """
    previous_impl = os.environ.get('MODEL_IMPL_TYPE')
    os.environ['MODEL_IMPL_TYPE'] = 'vllm'
    try:
        # Test with hybrid kv cache allocation enabled.
        outputs = _run_inference_with_config(
            model_name=model_name,
            test_prompts=test_prompts,
            sampling_params=sampling_params,
            disable_hybrid_kv_cache_manager=False,
        )
    finally:
        if previous_impl is None:
            os.environ.pop('MODEL_IMPL_TYPE', None)
        else:
            os.environ['MODEL_IMPL_TYPE'] = previous_impl

    # Verify we got outputs for all prompts.
    assert len(outputs) == len(test_prompts)

    # Verify each output has non-empty generated text.
    for output in outputs:
        assert len(output.outputs) > 0
        assert len(output.outputs[0].text.strip()) > 0

    print(f"✓ Hybrid KV cache test passed with {len(outputs)} outputs")
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _compare_hybrid_outputs(baseline_outputs, hybrid_kvcache_outputs):
    """Compare baseline and hybrid-KV-cache generations prompt by prompt.

    Prints every text mismatch and every per-token logprob difference above
    1e-3.

    Returns:
        Tuple ``(text_matches, text_mismatches, logprob_mismatches,
        max_logprob_diff)``.

    Raises:
        AssertionError: if a prompt's two runs return logprob lists of
            different lengths.
    """
    text_matches = 0
    text_mismatches = 0
    logprob_mismatches = 0
    max_logprob_diff = 0.0

    for i, (baseline, hybrid_kvcache_result) in enumerate(
            zip(baseline_outputs, hybrid_kvcache_outputs)):
        baseline_text = baseline.outputs[0].text.strip()
        hybrid_kvcache_text = hybrid_kvcache_result.outputs[0].text.strip()

        # Check text output.
        if baseline_text == hybrid_kvcache_text:
            text_matches += 1
        else:
            text_mismatches += 1
            print(f"Text mismatch found in prompt {i}:")
            print(f" Baseline: {baseline_text}")
            print(f" Hybrid KV Cache: {hybrid_kvcache_text}")

        # Check log probabilities.
        baseline_logprobs = baseline.outputs[0].logprobs
        hybrid_kvcache_logprobs = hybrid_kvcache_result.outputs[0].logprobs
        if baseline_logprobs is None or hybrid_kvcache_logprobs is None:
            continue

        assert len(baseline_logprobs) == len(hybrid_kvcache_logprobs), \
            f"Logprobs length mismatch: {len(baseline_logprobs)} vs {len(hybrid_kvcache_logprobs)}"

        for token_idx, (base_lp, hybrid_kvcache_lp) in enumerate(
                zip(baseline_logprobs, hybrid_kvcache_logprobs)):
            if not (base_lp and hybrid_kvcache_lp):
                continue

            # NOTE(review): the first dict key is treated as the top token,
            # and positions are still compared after the two texts diverge
            # (at which point the tokens being compared may differ).
            base_top_token = next(iter(base_lp))
            hybrid_kvcache_top_token = next(iter(hybrid_kvcache_lp))

            base_logprob_val = base_lp[base_top_token].logprob
            hybrid_kvcache_logprob_val = hybrid_kvcache_lp[
                hybrid_kvcache_top_token].logprob

            diff = abs(base_logprob_val - hybrid_kvcache_logprob_val)
            max_logprob_diff = max(max_logprob_diff, diff)

            # Allow small numerical differences (1e-3).
            if diff > 1e-3:
                logprob_mismatches += 1
                print(f"Logprob mismatch in prompt {i}, token {token_idx}:")
                print(
                    f" Baseline token: {base_top_token}, logprob: {base_logprob_val:.6f}"
                )
                print(
                    f" Hybrid KV Cache token: {hybrid_kvcache_top_token}, logprob: {hybrid_kvcache_logprob_val:.6f}"
                )
                print(f" Difference: {diff:.6f}")

    return text_matches, text_mismatches, logprob_mismatches, max_logprob_diff


def test_hybrid_kv_cache_correctness(
    model_name: str,
    test_prompts: list,
    sampling_params: SamplingParams,
):
    """Check hybrid KV cache allocation against standard allocation.

    Both runs use greedy sampling. The environment overrides below are
    applied only while the two engines run, then restored, so they do not
    leak into later tests.

    MODEL_IMPL_TYPE is pinned to 'vllm' explicitly. Previously this test
    inherited whatever value an earlier test had left in the environment,
    which made its behavior depend on test order.
    """
    env_overrides = {
        'MODEL_IMPL_TYPE': 'vllm',
        'SKIP_JAX_PRECOMPILE': '1',
        'VLLM_XLA_CHECK_RECOMPILATION': '0',
    }
    saved_env = {key: os.environ.get(key) for key in env_overrides}
    os.environ.update(env_overrides)
    try:
        # Run baseline (no hybrid kv cache).
        baseline_outputs = _run_inference_with_config(
            model_name=model_name,
            test_prompts=test_prompts,
            sampling_params=sampling_params,
            disable_hybrid_kv_cache_manager=True,
        )

        # Run with hybrid kv cache enabled.
        hybrid_kvcache_outputs = _run_inference_with_config(
            model_name=model_name,
            test_prompts=test_prompts,
            sampling_params=sampling_params,
            disable_hybrid_kv_cache_manager=False,
        )
    finally:
        for key, value in saved_env.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value

    # In theory greedy outputs should be identical. In practice they can
    # differ slightly while both remain acceptable, e.g. for
    # "What is the capital of Australia?":
    # The capital of Australia is Canberra. It is located in the Australian Capital Territory (ACT) and is home to many
    # Canberra is the capital of Australia. It is located in the Australian Capital Territory (ACT) and is home to
    assert len(baseline_outputs) == len(hybrid_kvcache_outputs)

    (text_matches, text_mismatches, logprob_mismatches,
     max_logprob_diff) = _compare_hybrid_outputs(baseline_outputs,
                                                 hybrid_kvcache_outputs)

    print("✓ Correctness test results:")
    print(f" Text: {text_matches} matches, {text_mismatches} mismatches")
    print(f" Max logprob difference: {max_logprob_diff:.6e}")
    print(f" Significant logprob mismatches (>1e-3): {logprob_mismatches}")

    # Allow some variance, but most greedy outputs should match.
    text_match_rate = text_matches / len(baseline_outputs)
    assert text_match_rate >= 0.9, f"Text match rate {text_match_rate:.2%} is too low"

    # Loose bound: logprobs are compared position by position even after
    # the texts diverge, so this is far larger than the per-token 1e-3
    # tolerance.
    assert max_logprob_diff < 2, f"Max logprob difference {max_logprob_diff} is too large"
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
|
|
15
15
|
from __future__ import annotations
|
|
16
16
|
|
|
17
|
+
import os
|
|
17
18
|
import random
|
|
18
19
|
import string
|
|
19
20
|
import time
|
|
@@ -22,6 +23,19 @@ import pytest
|
|
|
22
23
|
from vllm import LLM, SamplingParams
|
|
23
24
|
|
|
24
25
|
|
|
26
|
+
# TODO (Qiliang Cui): remove this when XLA fixes the recursive jit call issue.
|
|
27
|
+
def _is_v7x():
|
|
28
|
+
# jax.devices() will hang so use IS_FOR_V7X to indicate the version.
|
|
29
|
+
return os.environ.get("IS_FOR_V7X", "false") == "true"
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _get_tensor_parallel_size():
    """Return the tensor-parallel degree for the speculative-decoding LLMs.

    v7x hosts use 2 to work around an XLA issue; every other host uses 1.
    """
    return 2 if _is_v7x() else 1
|
|
37
|
+
|
|
38
|
+
|
|
25
39
|
def get_ngram_test_prompts():
|
|
26
40
|
num_prompts = 100
|
|
27
41
|
prompts = []
|
|
@@ -87,7 +101,10 @@ def _test_correctness_helper(
|
|
|
87
101
|
with monkeypatch.context():
|
|
88
102
|
test_prompts = get_test_prompts(speculative_config)
|
|
89
103
|
|
|
90
|
-
ref_llm = LLM(model=model_name,
|
|
104
|
+
ref_llm = LLM(model=model_name,
|
|
105
|
+
max_model_len=1024,
|
|
106
|
+
max_num_seqs=4,
|
|
107
|
+
tensor_parallel_size=_get_tensor_parallel_size())
|
|
91
108
|
ref_outputs = ref_llm.generate(test_prompts, sampling_config)
|
|
92
109
|
|
|
93
110
|
del ref_llm
|
|
@@ -98,7 +115,8 @@ def _test_correctness_helper(
|
|
|
98
115
|
spec_llm = LLM(model=model_name,
|
|
99
116
|
speculative_config=speculative_config,
|
|
100
117
|
max_model_len=1024,
|
|
101
|
-
max_num_seqs=4
|
|
118
|
+
max_num_seqs=4,
|
|
119
|
+
tensor_parallel_size=_get_tensor_parallel_size())
|
|
102
120
|
spec_outputs = spec_llm.generate(test_prompts, sampling_config)
|
|
103
121
|
|
|
104
122
|
matches = 0
|
|
@@ -179,7 +197,8 @@ def _test_performance_helper(
|
|
|
179
197
|
ref_llm = LLM(model=model_name,
|
|
180
198
|
max_model_len=1024,
|
|
181
199
|
max_num_seqs=1,
|
|
182
|
-
enable_prefix_caching=False
|
|
200
|
+
enable_prefix_caching=False,
|
|
201
|
+
tensor_parallel_size=_get_tensor_parallel_size())
|
|
183
202
|
|
|
184
203
|
start_time = time.time()
|
|
185
204
|
_ = ref_llm.generate(test_prompts, sampling_config)
|
|
@@ -195,6 +214,7 @@ def _test_performance_helper(
|
|
|
195
214
|
speculative_config=speculative_config,
|
|
196
215
|
max_model_len=1024,
|
|
197
216
|
max_num_seqs=1,
|
|
217
|
+
tensor_parallel_size=_get_tensor_parallel_size(),
|
|
198
218
|
enable_prefix_caching=False)
|
|
199
219
|
|
|
200
220
|
start_time = time.time()
|
|
@@ -229,7 +249,7 @@ def test_ngram_performance_greedy(
|
|
|
229
249
|
"prompt_lookup_max": 2,
|
|
230
250
|
"prompt_lookup_min": 2,
|
|
231
251
|
"num_speculative_tokens": 4,
|
|
232
|
-
}, 3.0)
|
|
252
|
+
}, 1.2 if _is_v7x() else 3.0)
|
|
233
253
|
|
|
234
254
|
|
|
235
255
|
def test_ngram_performance_random(
|
|
@@ -251,7 +271,7 @@ def test_ngram_performance_random(
|
|
|
251
271
|
"prompt_lookup_max": 2,
|
|
252
272
|
"prompt_lookup_min": 2,
|
|
253
273
|
"num_speculative_tokens": 4,
|
|
254
|
-
}, 3.0)
|
|
274
|
+
}, 1.5 if _is_v7x() else 3.0)
|
|
255
275
|
|
|
256
276
|
|
|
257
277
|
def test_eagle3_correctness(
|
|
@@ -288,4 +308,4 @@ def test_eagle3_performance(
|
|
|
288
308
|
"model": "unkmaster/EAGLE3-LLaMA3.1-Instruct-8B",
|
|
289
309
|
"num_speculative_tokens": 2,
|
|
290
310
|
"draft_tensor_parallel_size": 1
|
|
291
|
-
}, 1.8)
|
|
311
|
+
}, 1.2 if _is_v7x() else 1.8)
|
tests/layers/jax/test_qwix.py
CHANGED
|
@@ -832,7 +832,7 @@ class TestGetDefaultQwixQuantizationConfig(unittest.TestCase):
|
|
|
832
832
|
# Patch the constants in the module where the function resides
|
|
833
833
|
self.patchers = [
|
|
834
834
|
patch(
|
|
835
|
-
"tpu_inference.models.jax.utils.qwix.qwix_utils.
|
|
835
|
+
"tpu_inference.models.jax.utils.qwix.qwix_utils.DEFAULT_DEEPSEEK_FP4_MLP_MOE_FP8_ATTN_CONFIG",
|
|
836
836
|
self.mock_deepseek_config),
|
|
837
837
|
patch(
|
|
838
838
|
"tpu_inference.models.jax.utils.qwix.qwix_utils.DEFAULT_LLAMA4_FP8_CONFIG",
|
|
@@ -251,12 +251,16 @@ def test_loading_model(model, mesh):
|
|
|
251
251
|
|
|
252
252
|
@pytest.mark.parametrize("model", MODELS)
|
|
253
253
|
@pytest.mark.parametrize("bias", [False, True])
|
|
254
|
-
@pytest.mark.parametrize("
|
|
255
|
-
test_utils.get_spmd_mesh(1),
|
|
256
|
-
test_utils.get_spmd_mesh(jax.local_device_count())
|
|
257
|
-
])
|
|
254
|
+
@pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
|
|
258
255
|
@pytest.mark.parametrize("enable_sp", [False, True])
|
|
259
|
-
|
|
256
|
+
@pytest.mark.parametrize("enable_attn_dp", [False, True])
|
|
257
|
+
def test_row_parallel_linear(model, bias, num_devices, enable_sp,
|
|
258
|
+
enable_attn_dp):
|
|
259
|
+
# Skip if enable_attn_dp is True but we don't have enough devices
|
|
260
|
+
if enable_attn_dp and num_devices < 2:
|
|
261
|
+
pytest.skip("enable_attn_dp requires at least 2 devices")
|
|
262
|
+
|
|
263
|
+
mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
|
|
260
264
|
dtype = torch.bfloat16
|
|
261
265
|
|
|
262
266
|
engine_args = EngineArgs(
|
|
@@ -287,12 +291,16 @@ def test_row_parallel_linear(model, bias, mesh, enable_sp):
|
|
|
287
291
|
|
|
288
292
|
@pytest.mark.parametrize("model", MODELS)
|
|
289
293
|
@pytest.mark.parametrize("bias", [False, True])
|
|
290
|
-
@pytest.mark.parametrize("
|
|
291
|
-
test_utils.get_spmd_mesh(1),
|
|
292
|
-
test_utils.get_spmd_mesh(jax.local_device_count())
|
|
293
|
-
])
|
|
294
|
+
@pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
|
|
294
295
|
@pytest.mark.parametrize("enable_sp", [False, True])
|
|
295
|
-
|
|
296
|
+
@pytest.mark.parametrize("enable_attn_dp", [False, True])
|
|
297
|
+
def test_column_parallel_linear(model, bias, num_devices, enable_sp,
|
|
298
|
+
enable_attn_dp):
|
|
299
|
+
# Skip if enable_attn_dp is True but we don't have enough devices
|
|
300
|
+
if enable_attn_dp and num_devices < 2:
|
|
301
|
+
pytest.skip("enable_attn_dp requires at least 2 devices")
|
|
302
|
+
|
|
303
|
+
mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
|
|
296
304
|
dtype = torch.bfloat16
|
|
297
305
|
|
|
298
306
|
engine_args = EngineArgs(
|
|
@@ -324,13 +332,17 @@ def test_column_parallel_linear(model, bias, mesh, enable_sp):
|
|
|
324
332
|
|
|
325
333
|
@pytest.mark.parametrize("model", MODELS)
|
|
326
334
|
@pytest.mark.parametrize("bias", [False, True])
|
|
327
|
-
@pytest.mark.parametrize("
|
|
328
|
-
test_utils.get_spmd_mesh(1),
|
|
329
|
-
test_utils.get_spmd_mesh(jax.local_device_count())
|
|
330
|
-
])
|
|
335
|
+
@pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
|
|
331
336
|
@pytest.mark.parametrize("enable_sp", [False, True])
|
|
332
337
|
@pytest.mark.parametrize("fuse_matmuls", [False, True])
|
|
333
|
-
|
|
338
|
+
@pytest.mark.parametrize("enable_attn_dp", [False, True])
|
|
339
|
+
def test_qkv_parallel_linear(model, bias, num_devices, enable_sp, fuse_matmuls,
|
|
340
|
+
enable_attn_dp):
|
|
341
|
+
# Skip if enable_attn_dp is True but we don't have enough devices
|
|
342
|
+
if enable_attn_dp and num_devices < 2:
|
|
343
|
+
pytest.skip("enable_attn_dp requires at least 2 devices")
|
|
344
|
+
|
|
345
|
+
mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
|
|
334
346
|
dtype = torch.bfloat16
|
|
335
347
|
|
|
336
348
|
engine_args = EngineArgs(
|
|
@@ -365,14 +377,17 @@ def test_qkv_parallel_linear(model, bias, mesh, enable_sp, fuse_matmuls):
|
|
|
365
377
|
|
|
366
378
|
@pytest.mark.parametrize("model", MODELS)
|
|
367
379
|
@pytest.mark.parametrize("bias", [False, True])
|
|
368
|
-
@pytest.mark.parametrize("
|
|
369
|
-
test_utils.get_spmd_mesh(1),
|
|
370
|
-
test_utils.get_spmd_mesh(jax.local_device_count())
|
|
371
|
-
])
|
|
380
|
+
@pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
|
|
372
381
|
@pytest.mark.parametrize("fuse_matmuls", [False, True])
|
|
373
382
|
@pytest.mark.parametrize("enable_sp", [False, True])
|
|
374
|
-
|
|
375
|
-
|
|
383
|
+
@pytest.mark.parametrize("enable_attn_dp", [False, True])
|
|
384
|
+
def test_merged_column_parallel_linear(model, bias, num_devices, fuse_matmuls,
|
|
385
|
+
enable_sp, enable_attn_dp):
|
|
386
|
+
# Skip if enable_attn_dp is True but we don't have enough devices
|
|
387
|
+
if enable_attn_dp and num_devices < 2:
|
|
388
|
+
pytest.skip("enable_attn_dp requires at least 2 devices")
|
|
389
|
+
|
|
390
|
+
mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
|
|
376
391
|
dtype = torch.bfloat16
|
|
377
392
|
|
|
378
393
|
engine_args = EngineArgs(
|
|
@@ -138,12 +138,16 @@ def test_loading_model(model, mesh):
|
|
|
138
138
|
|
|
139
139
|
@pytest.mark.parametrize("model", MODELS)
|
|
140
140
|
@pytest.mark.parametrize("bias", [False, True])
|
|
141
|
-
@pytest.mark.parametrize("
|
|
142
|
-
test_utils.get_spmd_mesh(1),
|
|
143
|
-
test_utils.get_spmd_mesh(jax.local_device_count())
|
|
144
|
-
])
|
|
141
|
+
@pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
|
|
145
142
|
@pytest.mark.parametrize("enable_sp", [False, True])
|
|
146
|
-
|
|
143
|
+
@pytest.mark.parametrize("enable_attn_dp", [False, True])
|
|
144
|
+
def test_row_parallel_linear(model, bias, num_devices, enable_sp,
|
|
145
|
+
enable_attn_dp):
|
|
146
|
+
# Skip if enable_attn_dp is True but we don't have enough devices
|
|
147
|
+
if enable_attn_dp and num_devices < 2:
|
|
148
|
+
pytest.skip("enable_attn_dp requires at least 2 devices")
|
|
149
|
+
|
|
150
|
+
mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
|
|
147
151
|
|
|
148
152
|
dtype = torch.bfloat16
|
|
149
153
|
|
|
@@ -209,12 +213,16 @@ def test_row_parallel_linear(model, bias, mesh, enable_sp):
|
|
|
209
213
|
|
|
210
214
|
@pytest.mark.parametrize("model", MODELS)
|
|
211
215
|
@pytest.mark.parametrize("bias", [False, True])
|
|
212
|
-
@pytest.mark.parametrize("
|
|
213
|
-
test_utils.get_spmd_mesh(1),
|
|
214
|
-
test_utils.get_spmd_mesh(jax.local_device_count())
|
|
215
|
-
])
|
|
216
|
+
@pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
|
|
216
217
|
@pytest.mark.parametrize("enable_sp", [False, True])
|
|
217
|
-
|
|
218
|
+
@pytest.mark.parametrize("enable_attn_dp", [False, True])
|
|
219
|
+
def test_column_parallel_linear(model, bias, num_devices, enable_sp,
|
|
220
|
+
enable_attn_dp):
|
|
221
|
+
# Skip if enable_attn_dp is True but we don't have enough devices
|
|
222
|
+
if enable_attn_dp and num_devices < 2:
|
|
223
|
+
pytest.skip("enable_attn_dp requires at least 2 devices")
|
|
224
|
+
|
|
225
|
+
mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
|
|
218
226
|
dtype = torch.bfloat16
|
|
219
227
|
|
|
220
228
|
engine_args = EngineArgs(
|
|
@@ -280,13 +288,17 @@ def test_column_parallel_linear(model, bias, mesh, enable_sp):
|
|
|
280
288
|
|
|
281
289
|
@pytest.mark.parametrize("model", MODELS)
|
|
282
290
|
@pytest.mark.parametrize("bias", [False, True])
|
|
283
|
-
@pytest.mark.parametrize("
|
|
284
|
-
test_utils.get_spmd_mesh(1),
|
|
285
|
-
test_utils.get_spmd_mesh(jax.local_device_count())
|
|
286
|
-
])
|
|
291
|
+
@pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
|
|
287
292
|
@pytest.mark.parametrize("enable_sp", [False, True])
|
|
288
293
|
@pytest.mark.parametrize("fuse_matmuls", [False, True])
|
|
289
|
-
|
|
294
|
+
@pytest.mark.parametrize("enable_attn_dp", [False, True])
|
|
295
|
+
def test_qkv_parallel_linear(model, bias, num_devices, enable_sp, fuse_matmuls,
|
|
296
|
+
enable_attn_dp):
|
|
297
|
+
# Skip if enable_attn_dp is True but we don't have enough devices
|
|
298
|
+
if enable_attn_dp and num_devices < 2:
|
|
299
|
+
pytest.skip("enable_attn_dp requires at least 2 devices")
|
|
300
|
+
|
|
301
|
+
mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
|
|
290
302
|
dtype = torch.bfloat16
|
|
291
303
|
|
|
292
304
|
engine_args = EngineArgs(
|
|
@@ -354,14 +366,17 @@ def test_qkv_parallel_linear(model, bias, mesh, enable_sp, fuse_matmuls):
|
|
|
354
366
|
|
|
355
367
|
@pytest.mark.parametrize("model", MODELS)
|
|
356
368
|
@pytest.mark.parametrize("bias", [False, True])
|
|
357
|
-
@pytest.mark.parametrize("
|
|
358
|
-
test_utils.get_spmd_mesh(1),
|
|
359
|
-
test_utils.get_spmd_mesh(jax.local_device_count())
|
|
360
|
-
])
|
|
369
|
+
@pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
|
|
361
370
|
@pytest.mark.parametrize("fuse_matmuls", [False, True])
|
|
362
371
|
@pytest.mark.parametrize("enable_sp", [False, True])
|
|
363
|
-
|
|
364
|
-
|
|
372
|
+
@pytest.mark.parametrize("enable_attn_dp", [False, True])
|
|
373
|
+
def test_merged_column_parallel_linear(model, bias, num_devices, fuse_matmuls,
|
|
374
|
+
enable_sp, enable_attn_dp):
|
|
375
|
+
# Skip if enable_attn_dp is True but we don't have enough devices
|
|
376
|
+
if enable_attn_dp and num_devices < 2:
|
|
377
|
+
pytest.skip("enable_attn_dp requires at least 2 devices")
|
|
378
|
+
|
|
379
|
+
mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
|
|
365
380
|
dtype = torch.bfloat16
|
|
366
381
|
|
|
367
382
|
engine_args = EngineArgs(
|
tests/layers/vllm/test_mxfp4.py
CHANGED
|
@@ -119,17 +119,22 @@ def test_quant_override(model, mesh):
|
|
|
119
119
|
assert quant_config.mesh == mesh
|
|
120
120
|
|
|
121
121
|
|
|
122
|
-
@pytest.mark.parametrize(
|
|
123
|
-
"mesh", [test_utils.get_spmd_mesh(1),
|
|
124
|
-
test_utils.get_spmd_mesh(2)])
|
|
122
|
+
@pytest.mark.parametrize("num_devices", [1, 2])
|
|
125
123
|
@pytest.mark.parametrize("num_tokens", [8])
|
|
126
124
|
@pytest.mark.parametrize("intermediate_size", [1024])
|
|
127
125
|
@pytest.mark.parametrize("hidden_size", [128])
|
|
128
126
|
@pytest.mark.parametrize("num_experts", [8])
|
|
129
127
|
@pytest.mark.parametrize("topk", [2])
|
|
130
128
|
@pytest.mark.parametrize("use_ep", [True, False])
|
|
131
|
-
|
|
132
|
-
|
|
129
|
+
@pytest.mark.parametrize("enable_attn_dp", [False, True])
|
|
130
|
+
def test_mxfp4_fused_moe(num_devices, num_tokens, intermediate_size,
|
|
131
|
+
hidden_size, num_experts, topk, use_ep,
|
|
132
|
+
enable_attn_dp):
|
|
133
|
+
# Skip if enable_attn_dp is True but we don't have enough devices
|
|
134
|
+
if enable_attn_dp and num_devices < 2:
|
|
135
|
+
pytest.skip("enable_attn_dp requires at least 2 devices")
|
|
136
|
+
|
|
137
|
+
mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
|
|
133
138
|
torch.manual_seed(42)
|
|
134
139
|
dtype = torch.bfloat16
|
|
135
140
|
|
|
@@ -201,16 +206,26 @@ def test_mxfp4_fused_moe(mesh, num_tokens, intermediate_size, hidden_size,
|
|
|
201
206
|
rtol=1e-1)
|
|
202
207
|
|
|
203
208
|
|
|
204
|
-
@pytest.mark.parametrize(
|
|
205
|
-
"mesh", [test_utils.get_spmd_mesh(1),
|
|
206
|
-
test_utils.get_spmd_mesh(2)])
|
|
209
|
+
@pytest.mark.parametrize("num_devices", [1, 2])
|
|
207
210
|
@pytest.mark.parametrize("num_tokens", [8])
|
|
208
211
|
@pytest.mark.parametrize("intermediate_size", [512])
|
|
209
212
|
@pytest.mark.parametrize("hidden_size", [1024])
|
|
210
213
|
@pytest.mark.parametrize("num_experts", [8])
|
|
211
214
|
@pytest.mark.parametrize("topk", [2])
|
|
212
|
-
|
|
213
|
-
|
|
215
|
+
@pytest.mark.parametrize("enable_attn_dp", [False, True])
|
|
216
|
+
def test_mxfp4_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,
|
|
217
|
+
hidden_size, num_experts, topk,
|
|
218
|
+
enable_attn_dp):
|
|
219
|
+
# Skip if enable_attn_dp is True but we don't have enough devices
|
|
220
|
+
if enable_attn_dp and num_devices < 2:
|
|
221
|
+
pytest.skip("enable_attn_dp requires at least 2 devices")
|
|
222
|
+
|
|
223
|
+
# Skip attn_dp tests for fused_moe_use_kernel since the kernel only supports 2D mesh
|
|
224
|
+
if enable_attn_dp:
|
|
225
|
+
pytest.skip(
|
|
226
|
+
"fused_moe kernel does not support attn_dp (requires 2D mesh)")
|
|
227
|
+
|
|
228
|
+
mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
|
|
214
229
|
|
|
215
230
|
torch.manual_seed(42)
|
|
216
231
|
dtype = torch.bfloat16
|