tpu-inference 0.12.0.dev20251222__py3-none-any.whl → 0.12.0.dev20251224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/core/test_dp_scheduler.py +128 -71
- tests/e2e/test_data_parallel.py +176 -280
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_speculative_decoding.py +26 -6
- tests/layers/jax/test_qwix.py +1 -1
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +36 -21
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +36 -21
- tests/layers/vllm/test_mxfp4.py +25 -10
- tests/layers/vllm/test_unquantized.py +61 -31
- tests/layers/vllm/utils.py +19 -4
- tests/models/common/test_model_loader.py +2 -2
- tests/models/jax/test_qwen2_5_vl.py +10 -11
- tests/runner/test_multimodal_manager.py +3 -3
- tests/runner/test_tpu_runner.py +67 -8
- tests/runner/test_tpu_runner_dp.py +66 -0
- tpu_inference/core/sched/dp_scheduler.py +65 -40
- tpu_inference/kernels/mla/v1/kernel.py +7 -26
- tpu_inference/layers/common/sharding.py +8 -3
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +3 -3
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +3 -3
- tpu_inference/layers/jax/attention/llama4_attention.py +3 -4
- tpu_inference/layers/jax/sample/sampling.py +1 -1
- tpu_inference/layers/vllm/fused_moe.py +51 -47
- tpu_inference/layers/vllm/quantization/common.py +14 -13
- tpu_inference/layers/vllm/quantization/mxfp4.py +21 -7
- tpu_inference/layers/vllm/quantization/unquantized.py +19 -7
- tpu_inference/layers/vllm/sharding.py +7 -4
- tpu_inference/models/common/model_loader.py +11 -14
- tpu_inference/models/jax/llama3.py +13 -10
- tpu_inference/models/jax/llama_guard_4.py +1 -1
- tpu_inference/models/jax/qwen2.py +3 -2
- tpu_inference/models/jax/qwen2_5_vl.py +4 -4
- tpu_inference/models/jax/utils/multi_modal_utils.py +4 -4
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +3 -3
- tpu_inference/models/vllm/vllm_model_wrapper.py +5 -2
- tpu_inference/platforms/tpu_platform.py +7 -7
- tpu_inference/runner/compilation_manager.py +43 -33
- tpu_inference/runner/kv_cache_manager.py +1 -2
- tpu_inference/runner/multimodal_manager.py +1 -1
- tpu_inference/runner/tpu_runner.py +12 -9
- tpu_inference/utils.py +31 -30
- tpu_inference/worker/tpu_worker.py +5 -2
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/METADATA +1 -1
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/RECORD +47 -46
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/WHEEL +0 -0
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/top_level.txt +0 -0
tests/e2e/test_data_parallel.py
CHANGED
@@ -3,38 +3,37 @@
 
 import os
 import time
-from dataclasses import asdict
 
 import pytest
-from vllm import LLM,
+from vllm import LLM, SamplingParams
 
 
 @pytest.fixture(autouse=True)
 def setup_new_model_design():
-    """Automatically set NEW_MODEL_DESIGN=1 for all tests."""
     os.environ['NEW_MODEL_DESIGN'] = '1'
+    os.environ['SKIP_JAX_PRECOMPILE'] = '0'
+    os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '1'
 
 
 @pytest.fixture
-def test_prompts():
-
+def test_prompts(num_prompts: int = 256) -> list:
+    base_text = (
+        "The rapid advancement of artificial intelligence has transformed numerous industries "
+        "and continues to reshape our understanding of technology's potential. Machine learning "
+        "algorithms have become increasingly sophisticated, enabling computers to perform tasks "
+        "that were once thought to require human intelligence. From natural language processing "
+        "to computer vision, AI systems are now capable of understanding context, recognizing "
+        "patterns, and making decisions with remarkable accuracy. " *
+        20 # Repeat to reach ~1k tokens
+    )
     return [
-        "
-
-        "The colors of the rainbow are",
-        "The future of AI is",
-        "The president of the United States is",
-        "How many players are on a standard soccer team?",
-        "In Greek mythology, who is the god of the sea?",
-        "What is the capital of Australia?",
-        "What is the largest planet in our solar system?",
-        "Who developed the theory of general relativity?",
+        f"Prompt {i}: {base_text} What are your thoughts on this topic?"
+        for i in range(num_prompts)
     ]
 
 
 @pytest.fixture
 def sampling_params():
-    """Standard sampling parameters for testing."""
     return SamplingParams(
         temperature=0.0,
         max_tokens=32,
@@ -52,24 +51,18 @@ def _run_inference_with_config(model_name: str,
                                kv_cache_dtype: str = "auto",
                                enable_prefix_caching: bool = False,
                                async_scheduling: bool = False,
-                               measure_time: bool = False,
                                max_model_len: int = 32,
                                max_num_batched_tokens: int = 128,
-                               max_num_seqs: int = 16
-
+                               max_num_seqs: int = 16,
+                               gpu_memory_utilization: float = 0.90,
+                               trace_dir: str = None) -> list:
 
-
-        If measure_time=True: (outputs, elapsed_time) tuple
-        If measure_time=False: outputs list
-    """
-
-    # Create LLM args using parser-based approach similar to offline_inference.py
-    engine_args = EngineArgs(
+    llm = LLM(
         model=model_name,
         max_model_len=max_model_len,
         tensor_parallel_size=tensor_parallel_size,
         data_parallel_size=data_parallel_size,
-        gpu_memory_utilization=
+        gpu_memory_utilization=gpu_memory_utilization,
         max_num_batched_tokens=max_num_batched_tokens,
         max_num_seqs=max_num_seqs,
         enable_prefix_caching=enable_prefix_caching,
@@ -78,149 +71,112 @@ def _run_inference_with_config(model_name: str,
         async_scheduling=async_scheduling,
     )
 
-
-
+    start_time = time.time()
+    outputs = llm.generate(test_prompts, sampling_params)
+    elapsed_time = time.time() - start_time
 
-
-
-
-        elapsed_time = time.time() - start_time
-        if measure_time:
-            return outputs, elapsed_time
-        else:
-            return outputs
-    finally:
-        del llm
-        # Wait for TPUs to be released
-        time.sleep(5)
+    del llm
+    time.sleep(10)
+    return outputs, elapsed_time
 
 
-def
-
-    Test that data parallelism provides performance improvements compared to baseline.
-    This test measures the execution time with 128 prompts of length ~1k tokens.
+def _check_performance(test_name: str, baseline_time: float, dp_time: float,
+                       num_prompts: int, tol: float):
 
-
-    """
-    os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '1'
-    os.environ['SKIP_JAX_PRECOMPILE'] = '0'
-    os.environ['MODEL_IMPL_TYPE'] = 'flax_nnx'
+    speedup = baseline_time / dp_time if dp_time > 0 else 0
 
-
+    print(f"✓ {test_name} performance test results:")
+    print(f" Number of prompts: {num_prompts}")
+    print(f" Baseline time: {baseline_time:.2f}s")
+    print(f" Data parallel time: {dp_time:.2f}s")
+    print(f" Speedup: {speedup:.2f}x")
+    print(f" Baseline throughput: {num_prompts/baseline_time:.2f} prompts/s")
+    print(f" Data parallel throughput: {num_prompts/dp_time:.2f} prompts/s")
 
-
-    # Creating a base prompt of about 1k tokens using repeated text
-    base_text = (
-        "The rapid advancement of artificial intelligence has transformed numerous industries "
-        "and continues to reshape our understanding of technology's potential. Machine learning "
-        "algorithms have become increasingly sophisticated, enabling computers to perform tasks "
-        "that were once thought to require human intelligence. From natural language processing "
-        "to computer vision, AI systems are now capable of understanding context, recognizing "
-        "patterns, and making decisions with remarkable accuracy. " *
-        20 # Repeat to reach ~1k tokens
-    )
+    assert speedup >= tol, f"Data parallelism did not provide expected speedup ({tol:.2f}x): {speedup:.2f}x"
 
-    # Create 128 prompts with slight variations
-    long_prompts = [
-        f"Prompt {i}: {base_text} What are your thoughts on this topic?"
-        for i in range(128)
-    ]
 
-
-        f"Generated {len(long_prompts)} prompts, approximate length: {len(base_text.split())} tokens each"
-    )
+def _check_correctness(test_name, baseline_outputs, dp_outputs):
 
-
-    max_model_len = 2048
-    max_num_batched_tokens = 4096
-    max_num_seqs = 64
+    assert len(baseline_outputs) == len(dp_outputs)
 
-
-
-
-
-        sampling_params=sampling_params,
-        tensor_parallel_size=1,
-        data_parallel_size=1,
-        async_scheduling=True,
-        measure_time=True,
-        max_model_len=max_model_len,
-        max_num_batched_tokens=max_num_batched_tokens,
-        max_num_seqs=max_num_seqs,
-    )
+    text_matches = 0
+    logprob_matches = 0
+    total_compared_logprobs = 0
+    max_logprob_diff = 0.0
 
-
-
-
-
-        sampling_params=sampling_params,
-        tensor_parallel_size=1,
-        data_parallel_size=2,
-        async_scheduling=True,
-        measure_time=True,
-        max_model_len=max_model_len,
-        max_num_batched_tokens=max_num_batched_tokens,
-        max_num_seqs=max_num_seqs,
-    )
+    for i, (baseline, dp_result) in enumerate(zip(baseline_outputs,
+                                                  dp_outputs)):
+        baseline_text = baseline.outputs[0].text.strip()
+        dp_text = dp_result.outputs[0].text.strip()
 
-
-
+        baseline_words = baseline_text.split()
+        dp_words = dp_text.split()
+        overlap_set = set(baseline_words) & set(dp_words)
+        match_percent = len(overlap_set) / len(set(baseline_words))
+        if match_percent >= 0.7:
+            text_matches += 1
 
-
-
-
-
-
-
-        f" Baseline throughput: {len(long_prompts)/baseline_time:.2f} prompts/s"
-    )
-    print(
-        f" Data parallel throughput: {len(long_prompts)/dp_time:.2f} prompts/s"
-    )
+        # Check text output
+        if baseline_text != dp_text:
+            print(f"Text mismatch found in prompt {i}:")
+            print(f" Baseline: {baseline_text}")
+            print(f" Data Parallel: {dp_text}")
+            print(f" Match percent: {match_percent:.2%}")
 
+        # Check log probabilities
+        baseline_logprobs = baseline.outputs[0].logprobs
+        dp_logprobs = dp_result.outputs[0].logprobs
 
-
-
-
-
-    model_impl_type: str,
-):
-    """
-    Test model-wise data parallelism where data=2 in the mesh axis.
-    This test verifies that the model can run with data parallelism enabled,
-    duplicating the entire model across 2 data parallel workers.
+        if baseline_logprobs is not None and dp_logprobs is not None:
+            # Compare log probabilities for each token
+            assert len(baseline_logprobs) == len(dp_logprobs), \
+                f"Logprobs length mismatch: {len(baseline_logprobs)} vs {len(dp_logprobs)}"
 
-
-
-
-
-
-
-
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
-
-    # Test with data parallelism enabled
-    outputs = _run_inference_with_config(
-        model_name=test_model,
-        test_prompts=test_prompts,
-        sampling_params=sampling_params,
-        tensor_parallel_size=1,
-        data_parallel_size=2,
-        async_scheduling=False,
-    )
+            for token_idx, (base_lp, dp_lp) in enumerate(
+                    zip(baseline_logprobs, dp_logprobs)):
+                # Get the top logprob value for the selected token
+                if base_lp and dp_lp:
+                    # Get the top token's logprob from each
+                    base_top_token = list(base_lp.keys())[0]
+                    dp_top_token = list(dp_lp.keys())[0]
 
-
-
-
-
+                    # Only compare logprobs if tokens match
+                    if base_top_token == dp_top_token:
+                        base_logprob_val = base_lp[base_top_token].logprob
+                        dp_logprob_val = dp_lp[dp_top_token].logprob
+
+                        # Calculate absolute difference
+                        diff = abs(base_logprob_val - dp_logprob_val)
+                        max_logprob_diff = max(max_logprob_diff, diff)
+
+                        total_compared_logprobs += 1
+                        # Count as match if difference is small
+                        if diff < 0.1:
+                            logprob_matches += 1
+                        else:
+                            print(
+                                f" Logprob mismatch in prompt {i}, token {token_idx}: "
+                                f"Baseline logprob={base_logprob_val}, "
+                                f"Data Parallel logprob={dp_logprob_val}, "
+                                f"Diff={diff:.6e}")
+
+    print(f"✓ {test_name} correctness test results:")
+    print(f" Text: {text_matches} matches (match percent >= 70%)")
+    print(
+        f" Logprobs: {logprob_matches}/{total_compared_logprobs} ({logprob_matches / total_compared_logprobs:.2%}) matches (diff < 0.1)"
+    )
+    print(f" Max logprob difference: {max_logprob_diff:.6e}")
 
-    #
-
-
-
-            output.outputs[0].text.strip()) > 0, "Generated text is empty"
+    # Allow for some variance due to potential numerical differences
+    # but most outputs should match with greedy sampling
+    text_match_rate = text_matches / len(baseline_outputs)
+    assert text_match_rate >= 0.9, f"Text match rate {text_match_rate:.2%} is too low"
 
-
+    # Log probabilities should match for most matching tokens
+    if total_compared_logprobs > 0:
+        logprob_match_rate = logprob_matches / total_compared_logprobs
+        assert logprob_match_rate >= 0.9, f"Logprob match rate {logprob_match_rate:.2%} is too low"
 
 
 def test_attention_data_parallelism(
@@ -228,166 +184,106 @@ def test_attention_data_parallelism(
     sampling_params: SamplingParams,
 ):
     """
-
-    attn_dp=2 in the mesh axis. This is useful when num_kv_heads < TP to avoid
-    wasting KV cache memory.
-
-    Equivalent to:
-    python examples/offline_inference.py --tensor_parallel_size=4 --kv-cache-dtype=fp8 \
-    --additional_config='{"sharding":{"sharding_strategy": {"enable_dp_attention":1}}}'
+    Correctness and performance test for attention DP
     """
-    # Use Qwen3 0.6B for this test with reduced tensor parallelism
-    test_model = "Qwen/Qwen3-0.6B"
 
-    os.environ['MODEL_IMPL_TYPE'] = "
-
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
+    os.environ['MODEL_IMPL_TYPE'] = "vllm"
+    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
 
-
-
-
-
-            }
-        }
-    }
+    # Configuration for long sequences
+    max_model_len = 2048
+    max_num_batched_tokens = 4096
+    max_num_seqs = 128
 
-    #
-
-
-        model_name=test_model,
+    # Run with attn_dp=2 tp=2
+    dp_outputs, dp_time = _run_inference_with_config(
+        model_name=model_name,
         test_prompts=test_prompts,
         sampling_params=sampling_params,
         tensor_parallel_size=4,
+        async_scheduling=False,
+        max_model_len=max_model_len,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_num_seqs=max_num_seqs,
+        additional_config={
+            "sharding": {
+                "sharding_strategy": {
+                    "enable_dp_attention": 1
+                }
+            }
+        })
+
+    # Run baseline (tp=2)
+    baseline_outputs, baseline_time = _run_inference_with_config(
+        model_name=model_name,
+        test_prompts=test_prompts,
+        sampling_params=sampling_params,
+        tensor_parallel_size=2,
         data_parallel_size=1,
-
-
+        async_scheduling=False,
+        max_model_len=max_model_len,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_num_seqs=max_num_seqs,
     )
 
-
-
-        test_prompts
-    ), f"Expected {len(test_prompts)} outputs, got {len(outputs)}"
+    _check_correctness("Attention data parallelism", baseline_outputs,
+                       dp_outputs)
 
-    #
-
-
-
-
+    # Different hardware gives different performance. This test runs on v6e_8
+    _check_performance("Attention data parallelism",
+                       baseline_time,
+                       dp_time,
+                       len(test_prompts),
+                       tol=1.1)
 
-    print(
-        f"✓ Attention data parallelism test passed with {len(outputs)} outputs"
-    )
 
-
-def test_data_parallelism_correctness(
-    test_prompts: list,
+def test_data_parallelism(
     sampling_params: SamplingParams,
+    test_prompts: list,
 ):
     """
-
-    This test compares outputs from a single-device run with data parallel runs
-    to ensure correctness, including log probabilities.
+    Correctness and performance test for model DP
     """
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
-    os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '0'
     os.environ['MODEL_IMPL_TYPE'] = "flax_nnx"
 
     model_name = "Qwen/Qwen2.5-1.5B-Instruct"
-    # Use a smaller subset of prompts for correctness testing
-    small_prompts = test_prompts[:10]
 
-    #
-
+    # Configuration for long sequences
+    max_model_len = 2048
+    max_num_batched_tokens = 4096
+    max_num_seqs = 128
+
+    # Run with data parallelism (dp=2, tp=1)
+    dp_outputs, dp_time = _run_inference_with_config(
         model_name=model_name,
-        test_prompts=
+        test_prompts=test_prompts,
         sampling_params=sampling_params,
         tensor_parallel_size=1,
-        data_parallel_size=
+        data_parallel_size=2,
        async_scheduling=True,
+        max_model_len=max_model_len,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_num_seqs=max_num_seqs,
    )
 
-    # Run
-
+    # Run baseline (tp=1)
+    baseline_outputs, baseline_time = _run_inference_with_config(
         model_name=model_name,
-        test_prompts=
+        test_prompts=test_prompts,
         sampling_params=sampling_params,
         tensor_parallel_size=1,
-        data_parallel_size=
+        data_parallel_size=1,
         async_scheduling=True,
+        max_model_len=max_model_len,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_num_seqs=max_num_seqs,
     )
 
-
-    assert len(baseline_outputs) == len(dp_outputs)
-
-    text_matches = 0
-    text_mismatches = 0
-    logprob_mismatches = 0
-    max_logprob_diff = 0.0
-
-    for i, (baseline, dp_result) in enumerate(zip(baseline_outputs,
-                                                  dp_outputs)):
-        baseline_text = baseline.outputs[0].text.strip()
-        dp_text = dp_result.outputs[0].text.strip()
-
-        # Check text output
-        if baseline_text == dp_text:
-            text_matches += 1
-        else:
-            text_mismatches += 1
-            print(f"Text mismatch found in prompt {i}:")
-            print(f" Baseline: {baseline_text}")
-            print(f" Data Parallel: {dp_text}")
-
-        # Check log probabilities
-        baseline_logprobs = baseline.outputs[0].logprobs
-        dp_logprobs = dp_result.outputs[0].logprobs
-
-        if baseline_logprobs is not None and dp_logprobs is not None:
-            # Compare log probabilities for each token
-            assert len(baseline_logprobs) == len(dp_logprobs), \
-                f"Logprobs length mismatch: {len(baseline_logprobs)} vs {len(dp_logprobs)}"
-
-            for token_idx, (base_lp, dp_lp) in enumerate(
-                    zip(baseline_logprobs, dp_logprobs)):
-                # Get the top logprob value for the selected token
-                if base_lp and dp_lp:
-                    # Get the top token's logprob from each
-                    base_top_token = list(base_lp.keys())[0]
-                    dp_top_token = list(dp_lp.keys())[0]
-
-                    base_logprob_val = base_lp[base_top_token].logprob
-                    dp_logprob_val = dp_lp[dp_top_token].logprob
-
-                    # Calculate absolute difference
-                    diff = abs(base_logprob_val - dp_logprob_val)
-                    max_logprob_diff = max(max_logprob_diff, diff)
-
-                    # Allow small numerical differences
-                    if diff > 0.15:
-                        logprob_mismatches += 1
-                        print(
-                            f"Logprob mismatch in prompt {i}, token {token_idx}:"
-                        )
-                        print(
-                            f" Baseline token: {base_top_token}, logprob: {base_logprob_val:.6f}"
-                        )
-                        print(
-                            f" DP token: {dp_top_token}, logprob: {dp_logprob_val:.6f}"
-                        )
-                        print(f" Difference: {diff:.6f}")
-
-    print("✓ Correctness test results:")
-    print(f" Text: {text_matches} matches, {text_mismatches} mismatches")
-    print(f" Max logprob difference: {max_logprob_diff:.6e}")
-    print(f" Significant logprob mismatches (>0.15): {logprob_mismatches}")
-
-    # Allow for some variance due to potential numerical differences
-    # but most outputs should match with greedy sampling
-    text_match_rate = text_matches / len(baseline_outputs)
-    assert text_match_rate >= 0.9, f"Text match rate {text_match_rate:.2%} is too low"
-
-    # Log probabilities should be very close (allow small numerical errors)
-    assert max_logprob_diff < 0.15, f"Max logprob difference {max_logprob_diff} is too large"
+    _check_correctness("Data parallelism", baseline_outputs, dp_outputs)
 
-    #
-
+    # Test is too small to see significant speedup, mainly for testing regression
+    _check_performance("Data parallelism",
+                       baseline_time,
+                       dp_time,
+                       len(test_prompts),
+                       tol=1.1)