tpu-inference 0.12.0.dev20251222__py3-none-any.whl → 0.12.0.dev20251224__py3-none-any.whl

This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (47)
  1. tests/core/test_dp_scheduler.py +128 -71
  2. tests/e2e/test_data_parallel.py +176 -280
  3. tests/e2e/test_hybrid_kvcache.py +219 -0
  4. tests/e2e/test_speculative_decoding.py +26 -6
  5. tests/layers/jax/test_qwix.py +1 -1
  6. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +36 -21
  7. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +36 -21
  8. tests/layers/vllm/test_mxfp4.py +25 -10
  9. tests/layers/vllm/test_unquantized.py +61 -31
  10. tests/layers/vllm/utils.py +19 -4
  11. tests/models/common/test_model_loader.py +2 -2
  12. tests/models/jax/test_qwen2_5_vl.py +10 -11
  13. tests/runner/test_multimodal_manager.py +3 -3
  14. tests/runner/test_tpu_runner.py +67 -8
  15. tests/runner/test_tpu_runner_dp.py +66 -0
  16. tpu_inference/core/sched/dp_scheduler.py +65 -40
  17. tpu_inference/kernels/mla/v1/kernel.py +7 -26
  18. tpu_inference/layers/common/sharding.py +8 -3
  19. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +3 -3
  20. tpu_inference/layers/jax/attention/gpt_oss_attention.py +3 -3
  21. tpu_inference/layers/jax/attention/llama4_attention.py +3 -4
  22. tpu_inference/layers/jax/sample/sampling.py +1 -1
  23. tpu_inference/layers/vllm/fused_moe.py +51 -47
  24. tpu_inference/layers/vllm/quantization/common.py +14 -13
  25. tpu_inference/layers/vllm/quantization/mxfp4.py +21 -7
  26. tpu_inference/layers/vllm/quantization/unquantized.py +19 -7
  27. tpu_inference/layers/vllm/sharding.py +7 -4
  28. tpu_inference/models/common/model_loader.py +11 -14
  29. tpu_inference/models/jax/llama3.py +13 -10
  30. tpu_inference/models/jax/llama_guard_4.py +1 -1
  31. tpu_inference/models/jax/qwen2.py +3 -2
  32. tpu_inference/models/jax/qwen2_5_vl.py +4 -4
  33. tpu_inference/models/jax/utils/multi_modal_utils.py +4 -4
  34. tpu_inference/models/jax/utils/qwix/qwix_utils.py +3 -3
  35. tpu_inference/models/vllm/vllm_model_wrapper.py +5 -2
  36. tpu_inference/platforms/tpu_platform.py +7 -7
  37. tpu_inference/runner/compilation_manager.py +43 -33
  38. tpu_inference/runner/kv_cache_manager.py +1 -2
  39. tpu_inference/runner/multimodal_manager.py +1 -1
  40. tpu_inference/runner/tpu_runner.py +12 -9
  41. tpu_inference/utils.py +31 -30
  42. tpu_inference/worker/tpu_worker.py +5 -2
  43. {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/METADATA +1 -1
  44. {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/RECORD +47 -46
  45. {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/WHEEL +0 -0
  46. {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/licenses/LICENSE +0 -0
  47. {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/top_level.txt +0 -0
@@ -3,38 +3,37 @@
 
 import os
 import time
-from dataclasses import asdict
 
 import pytest
-from vllm import LLM, EngineArgs, SamplingParams
+from vllm import LLM, SamplingParams
 
 
 @pytest.fixture(autouse=True)
 def setup_new_model_design():
-    """Automatically set NEW_MODEL_DESIGN=1 for all tests."""
     os.environ['NEW_MODEL_DESIGN'] = '1'
+    os.environ['SKIP_JAX_PRECOMPILE'] = '0'
+    os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '1'
 
 
 @pytest.fixture
-def test_prompts():
-    """Simple test prompts for data parallelism testing."""
+def test_prompts(num_prompts: int = 256) -> list:
+    base_text = (
+        "The rapid advancement of artificial intelligence has transformed numerous industries "
+        "and continues to reshape our understanding of technology's potential. Machine learning "
+        "algorithms have become increasingly sophisticated, enabling computers to perform tasks "
+        "that were once thought to require human intelligence. From natural language processing "
+        "to computer vision, AI systems are now capable of understanding context, recognizing "
+        "patterns, and making decisions with remarkable accuracy. " *
+        20  # Repeat to reach ~1k tokens
+    )
     return [
-        "Hello, my name is",
-        "The capital of France is",
-        "The colors of the rainbow are",
-        "The future of AI is",
-        "The president of the United States is",
-        "How many players are on a standard soccer team?",
-        "In Greek mythology, who is the god of the sea?",
-        "What is the capital of Australia?",
-        "What is the largest planet in our solar system?",
-        "Who developed the theory of general relativity?",
+        f"Prompt {i}: {base_text} What are your thoughts on this topic?"
+        for i in range(num_prompts)
    ]
 
 
 @pytest.fixture
 def sampling_params():
-    """Standard sampling parameters for testing."""
     return SamplingParams(
         temperature=0.0,
         max_tokens=32,
@@ -52,24 +51,18 @@ def _run_inference_with_config(model_name: str,
                                kv_cache_dtype: str = "auto",
                                enable_prefix_caching: bool = False,
                                async_scheduling: bool = False,
-                               measure_time: bool = False,
                                max_model_len: int = 32,
                                max_num_batched_tokens: int = 128,
-                               max_num_seqs: int = 16):
-    """Helper function to run inference with specified configuration.
+                               max_num_seqs: int = 16,
+                               gpu_memory_utilization: float = 0.90,
+                               trace_dir: str = None) -> list:
 
-    Returns:
-        If measure_time=True: (outputs, elapsed_time) tuple
-        If measure_time=False: outputs list
-    """
-
-    # Create LLM args using parser-based approach similar to offline_inference.py
-    engine_args = EngineArgs(
+    llm = LLM(
         model=model_name,
         max_model_len=max_model_len,
         tensor_parallel_size=tensor_parallel_size,
         data_parallel_size=data_parallel_size,
-        gpu_memory_utilization=0.98,
+        gpu_memory_utilization=gpu_memory_utilization,
         max_num_batched_tokens=max_num_batched_tokens,
         max_num_seqs=max_num_seqs,
         enable_prefix_caching=enable_prefix_caching,
@@ -78,149 +71,112 @@ def _run_inference_with_config(model_name: str,
         async_scheduling=async_scheduling,
     )
 
-    engine_args_dict = asdict(engine_args)
-    llm = LLM(**engine_args_dict)
+    start_time = time.time()
+    outputs = llm.generate(test_prompts, sampling_params)
+    elapsed_time = time.time() - start_time
 
-    try:
-        start_time = time.time()
-        outputs = llm.generate(test_prompts, sampling_params)
-        elapsed_time = time.time() - start_time
-        if measure_time:
-            return outputs, elapsed_time
-        else:
-            return outputs
-    finally:
-        del llm
-        # Wait for TPUs to be released
-        time.sleep(5)
+    del llm
+    time.sleep(10)
+    return outputs, elapsed_time
 
 
-def test_data_parallelism_performance(sampling_params: SamplingParams, ):
-    """
-    Test that data parallelism provides performance improvements compared to baseline.
-    This test measures the execution time with 128 prompts of length ~1k tokens.
+def _check_performance(test_name: str, baseline_time: float, dp_time: float,
+                       num_prompts: int, tol: float):
 
-    Note: This is a performance benchmark test with large prompts.
-    """
-    os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '1'
-    os.environ['SKIP_JAX_PRECOMPILE'] = '0'
-    os.environ['MODEL_IMPL_TYPE'] = 'flax_nnx'
+    speedup = baseline_time / dp_time if dp_time > 0 else 0
 
-    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
+    print(f"✓ {test_name} performance test results:")
+    print(f" Number of prompts: {num_prompts}")
+    print(f" Baseline time: {baseline_time:.2f}s")
+    print(f" Data parallel time: {dp_time:.2f}s")
+    print(f" Speedup: {speedup:.2f}x")
+    print(f" Baseline throughput: {num_prompts/baseline_time:.2f} prompts/s")
+    print(f" Data parallel throughput: {num_prompts/dp_time:.2f} prompts/s")
 
-    # Generate 128 prompts of approximately 1k tokens each
-    # Creating a base prompt of about 1k tokens using repeated text
-    base_text = (
-        "The rapid advancement of artificial intelligence has transformed numerous industries "
-        "and continues to reshape our understanding of technology's potential. Machine learning "
-        "algorithms have become increasingly sophisticated, enabling computers to perform tasks "
-        "that were once thought to require human intelligence. From natural language processing "
-        "to computer vision, AI systems are now capable of understanding context, recognizing "
-        "patterns, and making decisions with remarkable accuracy. " *
-        20  # Repeat to reach ~1k tokens
-    )
+    assert speedup >= tol, f"Data parallelism did not provide expected speedup ({tol:.2f}x): {speedup:.2f}x"
 
-    # Create 128 prompts with slight variations
-    long_prompts = [
-        f"Prompt {i}: {base_text} What are your thoughts on this topic?"
-        for i in range(128)
-    ]
 
-    print(
-        f"Generated {len(long_prompts)} prompts, approximate length: {len(base_text.split())} tokens each"
-    )
+def _check_correctness(test_name, baseline_outputs, dp_outputs):
 
-    # Configuration for long sequences
-    max_model_len = 2048
-    max_num_batched_tokens = 4096
-    max_num_seqs = 64
+    assert len(baseline_outputs) == len(dp_outputs)
 
-    # Run baseline (no data parallelism) with timing
-    baseline_outputs, baseline_time = _run_inference_with_config(
-        model_name=model_name,
-        test_prompts=long_prompts,
-        sampling_params=sampling_params,
-        tensor_parallel_size=1,
-        data_parallel_size=1,
-        async_scheduling=True,
-        measure_time=True,
-        max_model_len=max_model_len,
-        max_num_batched_tokens=max_num_batched_tokens,
-        max_num_seqs=max_num_seqs,
-    )
+    text_matches = 0
+    logprob_matches = 0
+    total_compared_logprobs = 0
+    max_logprob_diff = 0.0
 
-    # Run with model data parallelism and async scheduling with timing
-    dp_outputs, dp_time = _run_inference_with_config(
-        model_name=model_name,
-        test_prompts=long_prompts,
-        sampling_params=sampling_params,
-        tensor_parallel_size=1,
-        data_parallel_size=2,
-        async_scheduling=True,
-        measure_time=True,
-        max_model_len=max_model_len,
-        max_num_batched_tokens=max_num_batched_tokens,
-        max_num_seqs=max_num_seqs,
-    )
+    for i, (baseline, dp_result) in enumerate(zip(baseline_outputs,
+                                                  dp_outputs)):
+        baseline_text = baseline.outputs[0].text.strip()
+        dp_text = dp_result.outputs[0].text.strip()
 
-    # Calculate speedup
-    speedup = baseline_time / dp_time if dp_time > 0 else 0
+        baseline_words = baseline_text.split()
+        dp_words = dp_text.split()
+        overlap_set = set(baseline_words) & set(dp_words)
+        match_percent = len(overlap_set) / len(set(baseline_words))
+        if match_percent >= 0.7:
+            text_matches += 1
 
-    print("✓ Performance test results:")
-    print(f" Number of prompts: {len(long_prompts)}")
-    print(f" Baseline time: {baseline_time:.2f}s")
-    print(f" Data parallel time: {dp_time:.2f}s")
-    print(f" Speedup: {speedup:.2f}x")
-    print(
-        f" Baseline throughput: {len(long_prompts)/baseline_time:.2f} prompts/s"
-    )
-    print(
-        f" Data parallel throughput: {len(long_prompts)/dp_time:.2f} prompts/s"
-    )
+        # Check text output
+        if baseline_text != dp_text:
+            print(f"Text mismatch found in prompt {i}:")
+            print(f" Baseline: {baseline_text}")
+            print(f" Data Parallel: {dp_text}")
+            print(f" Match percent: {match_percent:.2%}")
 
+        # Check log probabilities
+        baseline_logprobs = baseline.outputs[0].logprobs
+        dp_logprobs = dp_result.outputs[0].logprobs
 
-@pytest.mark.parametrize("model_impl_type", ["vllm", "flax_nnx"])
-def test_model_data_parallelism(
-    test_prompts: list,
-    sampling_params: SamplingParams,
-    model_impl_type: str,
-):
-    """
-    Test model-wise data parallelism where data=2 in the mesh axis.
-    This test verifies that the model can run with data parallelism enabled,
-    duplicating the entire model across 2 data parallel workers.
+        if baseline_logprobs is not None and dp_logprobs is not None:
+            # Compare log probabilities for each token
+            assert len(baseline_logprobs) == len(dp_logprobs), \
+                f"Logprobs length mismatch: {len(baseline_logprobs)} vs {len(dp_logprobs)}"
 
-    Equivalent to:
-    python examples/offline_inference.py --tensor_parallel_size=4 --data_parallel_size=2
-    """
-    # Use Llama 1B for this test
-    test_model = "meta-llama/Llama-3.2-1B-Instruct"
-    os.environ['MODEL_IMPL_TYPE'] = model_impl_type
-    os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '0'
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
-
-    # Test with data parallelism enabled
-    outputs = _run_inference_with_config(
-        model_name=test_model,
-        test_prompts=test_prompts,
-        sampling_params=sampling_params,
-        tensor_parallel_size=1,
-        data_parallel_size=2,
-        async_scheduling=False,
-    )
+            for token_idx, (base_lp, dp_lp) in enumerate(
+                    zip(baseline_logprobs, dp_logprobs)):
+                # Get the top logprob value for the selected token
+                if base_lp and dp_lp:
+                    # Get the top token's logprob from each
+                    base_top_token = list(base_lp.keys())[0]
+                    dp_top_token = list(dp_lp.keys())[0]
 
-    # Verify we got outputs for all prompts
-    assert len(outputs) == len(
-        test_prompts
-    ), f"Expected {len(test_prompts)} outputs, got {len(outputs)}"
+                    # Only compare logprobs if tokens match
+                    if base_top_token == dp_top_token:
+                        base_logprob_val = base_lp[base_top_token].logprob
+                        dp_logprob_val = dp_lp[dp_top_token].logprob
+
+                        # Calculate absolute difference
+                        diff = abs(base_logprob_val - dp_logprob_val)
+                        max_logprob_diff = max(max_logprob_diff, diff)
+
+                        total_compared_logprobs += 1
+                        # Count as match if difference is small
+                        if diff < 0.1:
+                            logprob_matches += 1
+                        else:
+                            print(
+                                f" Logprob mismatch in prompt {i}, token {token_idx}: "
+                                f"Baseline logprob={base_logprob_val}, "
+                                f"Data Parallel logprob={dp_logprob_val}, "
+                                f"Diff={diff:.6e}")
+
+    print(f"✓ {test_name} correctness test results:")
+    print(f" Text: {text_matches} matches (match percent >= 70%)")
+    print(
+        f" Logprobs: {logprob_matches}/{total_compared_logprobs} ({logprob_matches / total_compared_logprobs:.2%}) matches (diff < 0.1)"
+    )
+    print(f" Max logprob difference: {max_logprob_diff:.6e}")
 
-    # Verify each output has generated text
-    for output in outputs:
-        assert len(output.outputs) > 0, "Output has no generated text"
-        assert len(
-            output.outputs[0].text.strip()) > 0, "Generated text is empty"
+    # Allow for some variance due to potential numerical differences
+    # but most outputs should match with greedy sampling
+    text_match_rate = text_matches / len(baseline_outputs)
+    assert text_match_rate >= 0.9, f"Text match rate {text_match_rate:.2%} is too low"
 
-    print(f"✓ Model data parallelism test passed with {len(outputs)} outputs")
+    # Log probabilities should match for most matching tokens
+    if total_compared_logprobs > 0:
+        logprob_match_rate = logprob_matches / total_compared_logprobs
+        assert logprob_match_rate >= 0.9, f"Logprob match rate {logprob_match_rate:.2%} is too low"
 
 
 def test_attention_data_parallelism(
@@ -228,166 +184,106 @@ def test_attention_data_parallelism(
     sampling_params: SamplingParams,
 ):
     """
-    Test attention data parallelism where only the attention layer gets duplicated,
-    attn_dp=2 in the mesh axis. This is useful when num_kv_heads < TP to avoid
-    wasting KV cache memory.
-
-    Equivalent to:
-    python examples/offline_inference.py --tensor_parallel_size=4 --kv-cache-dtype=fp8 \
-        --additional_config='{"sharding":{"sharding_strategy": {"enable_dp_attention":1}}}'
+    Correctness and performance test for attention DP
     """
-    # Use Qwen3 0.6B for this test with reduced tensor parallelism
-    test_model = "Qwen/Qwen3-0.6B"
 
-    os.environ['MODEL_IMPL_TYPE'] = "flax_nnx"
-    os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '0'
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
+    os.environ['MODEL_IMPL_TYPE'] = "vllm"
+    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
 
-    additional_config = {
-        "sharding": {
-            "sharding_strategy": {
-                "enable_dp_attention": 1
-            }
-        }
-    }
+    # Configuration for long sequences
+    max_model_len = 2048
+    max_num_batched_tokens = 4096
+    max_num_seqs = 128
 
-    # Test with attention data parallelism enabled
-    # Reduced tensor_parallel_size from 8 to 4 to avoid memory exhaustion
-    outputs = _run_inference_with_config(
-        model_name=test_model,
+    # Run with attn_dp=2 tp=2
+    dp_outputs, dp_time = _run_inference_with_config(
+        model_name=model_name,
         test_prompts=test_prompts,
         sampling_params=sampling_params,
         tensor_parallel_size=4,
+        async_scheduling=False,
+        max_model_len=max_model_len,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_num_seqs=max_num_seqs,
+        additional_config={
+            "sharding": {
+                "sharding_strategy": {
+                    "enable_dp_attention": 1
+                }
+            }
+        })
+
+    # Run baseline (tp=2)
+    baseline_outputs, baseline_time = _run_inference_with_config(
+        model_name=model_name,
+        test_prompts=test_prompts,
+        sampling_params=sampling_params,
+        tensor_parallel_size=2,
         data_parallel_size=1,
-        additional_config=additional_config,
-        kv_cache_dtype="fp8",
+        async_scheduling=False,
+        max_model_len=max_model_len,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_num_seqs=max_num_seqs,
     )
 
-    # Verify we got outputs for all prompts
-    assert len(outputs) == len(
-        test_prompts
-    ), f"Expected {len(test_prompts)} outputs, got {len(outputs)}"
+    _check_correctness("Attention data parallelism", baseline_outputs,
+                       dp_outputs)
 
-    # Verify each output has generated text
-    for output in outputs:
-        assert len(output.outputs) > 0, "Output has no generated text"
-        assert len(
-            output.outputs[0].text.strip()) > 0, "Generated text is empty"
+    # Different hardware gives different performance. This test runs on v6e_8
+    _check_performance("Attention data parallelism",
+                       baseline_time,
+                       dp_time,
+                       len(test_prompts),
+                       tol=1.1)
 
-    print(
-        f"✓ Attention data parallelism test passed with {len(outputs)} outputs"
-    )
 
-
-def test_data_parallelism_correctness(
-    test_prompts: list,
+def test_data_parallelism(
     sampling_params: SamplingParams,
+    test_prompts: list,
 ):
     """
-    Test that data parallelism produces consistent results compared to a baseline.
-    This test compares outputs from a single-device run with data parallel runs
-    to ensure correctness, including log probabilities.
+    Correctness and performance test for model DP
    """
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
-    os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '0'
     os.environ['MODEL_IMPL_TYPE'] = "flax_nnx"
 
     model_name = "Qwen/Qwen2.5-1.5B-Instruct"
-    # Use a smaller subset of prompts for correctness testing
-    small_prompts = test_prompts[:10]
 
-    # Run baseline (no data parallelism)
-    baseline_outputs = _run_inference_with_config(
+    # Configuration for long sequences
+    max_model_len = 2048
+    max_num_batched_tokens = 4096
+    max_num_seqs = 128
+
+    # Run with data parallelism (dp=2, tp=1)
+    dp_outputs, dp_time = _run_inference_with_config(
         model_name=model_name,
-        test_prompts=small_prompts,
+        test_prompts=test_prompts,
         sampling_params=sampling_params,
         tensor_parallel_size=1,
-        data_parallel_size=1,
+        data_parallel_size=2,
         async_scheduling=True,
+        max_model_len=max_model_len,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_num_seqs=max_num_seqs,
    )
 
-    # Run with model data parallelism and async scheduling
-    dp_outputs = _run_inference_with_config(
+    # Run baseline (tp=1)
+    baseline_outputs, baseline_time = _run_inference_with_config(
        model_name=model_name,
-        test_prompts=small_prompts,
+        test_prompts=test_prompts,
        sampling_params=sampling_params,
        tensor_parallel_size=1,
-        data_parallel_size=2,
+        data_parallel_size=1,
        async_scheduling=True,
+        max_model_len=max_model_len,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_num_seqs=max_num_seqs,
    )
 
-    # Compare outputs - they should be identical for greedy sampling
-    assert len(baseline_outputs) == len(dp_outputs)
-
-    text_matches = 0
-    text_mismatches = 0
-    logprob_mismatches = 0
-    max_logprob_diff = 0.0
-
-    for i, (baseline, dp_result) in enumerate(zip(baseline_outputs,
-                                                  dp_outputs)):
-        baseline_text = baseline.outputs[0].text.strip()
-        dp_text = dp_result.outputs[0].text.strip()
-
-        # Check text output
-        if baseline_text == dp_text:
-            text_matches += 1
-        else:
-            text_mismatches += 1
-            print(f"Text mismatch found in prompt {i}:")
-            print(f" Baseline: {baseline_text}")
-            print(f" Data Parallel: {dp_text}")
-
-        # Check log probabilities
-        baseline_logprobs = baseline.outputs[0].logprobs
-        dp_logprobs = dp_result.outputs[0].logprobs
-
-        if baseline_logprobs is not None and dp_logprobs is not None:
-            # Compare log probabilities for each token
-            assert len(baseline_logprobs) == len(dp_logprobs), \
-                f"Logprobs length mismatch: {len(baseline_logprobs)} vs {len(dp_logprobs)}"
-
-            for token_idx, (base_lp, dp_lp) in enumerate(
-                    zip(baseline_logprobs, dp_logprobs)):
-                # Get the top logprob value for the selected token
-                if base_lp and dp_lp:
-                    # Get the top token's logprob from each
-                    base_top_token = list(base_lp.keys())[0]
-                    dp_top_token = list(dp_lp.keys())[0]
-
-                    base_logprob_val = base_lp[base_top_token].logprob
-                    dp_logprob_val = dp_lp[dp_top_token].logprob
-
-                    # Calculate absolute difference
-                    diff = abs(base_logprob_val - dp_logprob_val)
-                    max_logprob_diff = max(max_logprob_diff, diff)
-
-                    # Allow small numerical differences
-                    if diff > 0.15:
-                        logprob_mismatches += 1
-                        print(
-                            f"Logprob mismatch in prompt {i}, token {token_idx}:"
-                        )
-                        print(
-                            f" Baseline token: {base_top_token}, logprob: {base_logprob_val:.6f}"
-                        )
-                        print(
-                            f" DP token: {dp_top_token}, logprob: {dp_logprob_val:.6f}"
-                        )
-                        print(f" Difference: {diff:.6f}")
-
-    print("✓ Correctness test results:")
-    print(f" Text: {text_matches} matches, {text_mismatches} mismatches")
-    print(f" Max logprob difference: {max_logprob_diff:.6e}")
-    print(f" Significant logprob mismatches (>0.15): {logprob_mismatches}")
-
-    # Allow for some variance due to potential numerical differences
-    # but most outputs should match with greedy sampling
-    text_match_rate = text_matches / len(baseline_outputs)
-    assert text_match_rate >= 0.9, f"Text match rate {text_match_rate:.2%} is too low"
-
-    # Log probabilities should be very close (allow small numerical errors)
-    assert max_logprob_diff < 0.15, f"Max logprob difference {max_logprob_diff} is too large"
+    _check_correctness("Data parallelism", baseline_outputs, dp_outputs)
 
-    # Log probabilities should be very close (allow small numerical errors)
-    assert max_logprob_diff < 0.15, f"Max logprob difference {max_logprob_diff} is too large"
+    # Test is too small to see significant speedup, mainly for testing regression
+    _check_performance("Data parallelism",
+                       baseline_time,
+                       dp_time,
+                       len(test_prompts),
+                       tol=1.1)