tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference has been flagged as potentially problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tests/e2e/test_async_scheduler.py
@@ -0,0 +1,211 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import random
+import string
+import time
+
+import pytest
+from vllm import LLM, SamplingParams
+
+
+@pytest.fixture
+def sampling_config():
+    return SamplingParams(temperature=0,
+                          max_tokens=120,
+                          ignore_eos=True,
+                          repetition_penalty=1,
+                          frequency_penalty=0,
+                          presence_penalty=0,
+                          min_p=0,
+                          logprobs=None)
+
+
+@pytest.fixture
+def model_name():
+    return "Qwen/Qwen2.5-1.5B-Instruct"
+
+
+def get_test_prompts():
+    """
+    Generates a list of prompts with a specific word count.
+
+    Args:
+        num_prompts: The number of prompts to generate.
+        input_len_words: The total number of words for each prompt.
+
+    Returns:
+        A list of num_prompts strings, each containing input_len_words words.
+    """
+    num_prompts = 500
+    input_len_words = 120
+    prompts = []
+
+    # For example, w = 's':
+    # the generated prompt will be "Keep repeating: s s s ..."
+    num_repetitions = input_len_words
+    prefix = "Keep repeating: "
+
+    for _ in range(num_prompts):
+        # 1. Pick a random lowercase letter
+        w = random.choice(list(string.ascii_lowercase))
+
+        # 2. Create the string of repeated words
+        #    (this will have num_repetitions words)
+        repeating_part = " ".join([w] * num_repetitions)
+
+        # 3. Combine with the prefix (if any)
+        print(f"{prefix}{repeating_part}")
+        prompts.append(f"{prefix}{repeating_part}")
+
+    return prompts
+
+
+def _test_performance_helper(monkeypatch: pytest.MonkeyPatch,
+                             sampling_config: SamplingParams, model_name: str,
+                             min_speedup: float):
+    '''
+    Helper function to test async scheduler decoding performance.
+    Compares timing between the reference LLM and the async LLM using Qwen2.5-1.5B.
+    '''
+
+    with monkeypatch.context():
+        # Use a smaller set of prompts for performance testing
+        test_prompts = get_test_prompts()  # num_prompts=500, input_len_words=120
+
+        # Test reference LLM timing
+        ref_llm = LLM(model=model_name,
+                      max_model_len=800,
+                      max_num_seqs=24,
+                      max_num_batched_tokens=512,
+                      enable_prefix_caching=False,
+                      async_scheduling=0)
+
+        start_time = time.time()
+        _ = ref_llm.generate(test_prompts, sampling_config)
+        ref_time = time.time() - start_time
+
+        del ref_llm
+        # Waiting for TPUs to be released
+        time.sleep(10)
+
+        # Test async LLM timing
+        async_llm = LLM(model=model_name,
+                        max_model_len=800,
+                        max_num_seqs=24,
+                        max_num_batched_tokens=512,
+                        enable_prefix_caching=False,
+                        async_scheduling=1)
+
+        start_time = time.time()
+        _ = async_llm.generate(test_prompts, sampling_config)
+        async_time = time.time() - start_time
+
+        del async_llm
+        # Waiting for TPUs to be released
+        time.sleep(10)
+
+        speedup = ref_time / async_time
+        print(f"Reference LLM time: {ref_time:.2f}s")
+        print(f"Async LLM time: {async_time:.2f}s")
+        print(f"Speedup: {speedup:.2f}x")
+
+        assert speedup >= min_speedup, f"Expected at least {min_speedup}x speedup for async scheduler, got {speedup:.2f}x"
+
+
+def test_performance(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+    model_name: str,
+):
+    '''
+    Test that async scheduler decoding provides a significant performance improvement.
+    Compares timing between the reference LLM and the async LLM using Qwen2.5-1.5B.
+    Expects async_llm to be at least 1.3x faster than ref_llm.
+    '''
+    min_speed_up = 1.3
+    _test_performance_helper(monkeypatch, sampling_config, model_name,
+                             min_speed_up)


+def _test_correctness_helper(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+    model_name: str,
+):
+    '''
+    Helper function to test async scheduler correctness.
+    The outputs of the original LLM and the async LLM should be the same
+    when using async scheduler decoding.
+
+    Known edge case (KV cache swapping):
+    In that case, even with temperature set to 0, the output still differs
+    slightly from run to run. This is expected behaviour (the normal
+    scheduler behaves the same way), which makes it difficult to design a
+    test for that scenario.
+    '''
+    with monkeypatch.context():
+        test_prompts = get_test_prompts()
+
+        ref_llm = LLM(model=model_name,
+                      max_model_len=1024,
+                      max_num_seqs=100,
+                      async_scheduling=0)
+        ref_outputs = ref_llm.generate(test_prompts, sampling_config)
+
+        del ref_llm
+
+        # Waiting for TPUs to be released.
+        time.sleep(10)
+
+        async_llm = LLM(model=model_name,
+                        max_model_len=1024,
+                        max_num_seqs=100,
+                        async_scheduling=1)
+        async_outputs = async_llm.generate(test_prompts, sampling_config)
+
+        matches = 0
+        misses = 0
+        for ref_output, async_output in zip(ref_outputs, async_outputs):
+            if ref_output.outputs[0].text == async_output.outputs[0].text:
+                print(f"ref_output: {ref_output.outputs[0].text}")
+                print(f"async_output: {async_output.outputs[0].text}")
+                matches += 1
+            else:
+                misses += 1
+                print(f"ref_output: {ref_output.outputs[0].text}")
+                print(f"async_output: {async_output.outputs[0].text}")
+
+        assert misses == 0
+        del async_outputs
+
+        # Waiting for TPUs to be released.
+        time.sleep(10)
+
+
+def test_async_correctness(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+    model_name: str,
+):
+    '''
+    The outputs of the original LLM and the async LLM should be the same
+    when using the async scheduler.
+    '''
+    _test_correctness_helper(monkeypatch, sampling_config, model_name)
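The async-scheduler tests above are plain pytest tests. A minimal sketch of invoking the correctness test programmatically is shown below; it assumes a TPU host with vllm and tpu-inference installed, and while the test id is taken from the diff, the invocation itself is only illustrative, not part of the package.

import sys

import pytest

# Run only the async-scheduler correctness test; "-s" keeps the timing printouts visible.
exit_code = pytest.main([
    "-s",
    "tests/e2e/test_async_scheduler.py::test_async_correctness",
])
sys.exit(int(exit_code))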
tests/e2e/test_data_parallel.py
@@ -0,0 +1,289 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import os
+import time
+
+import pytest
+from vllm import LLM, SamplingParams
+
+
+@pytest.fixture(autouse=True)
+def setup_new_model_design():
+    os.environ['NEW_MODEL_DESIGN'] = '1'
+    os.environ['SKIP_JAX_PRECOMPILE'] = '0'
+    os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '1'
+
+
+@pytest.fixture
+def test_prompts(num_prompts: int = 256) -> list:
+    base_text = (
+        "The rapid advancement of artificial intelligence has transformed numerous industries "
+        "and continues to reshape our understanding of technology's potential. Machine learning "
+        "algorithms have become increasingly sophisticated, enabling computers to perform tasks "
+        "that were once thought to require human intelligence. From natural language processing "
+        "to computer vision, AI systems are now capable of understanding context, recognizing "
+        "patterns, and making decisions with remarkable accuracy. " *
+        20  # Repeat to reach ~1k tokens
+    )
+    return [
+        f"Prompt {i}: {base_text} What are your thoughts on this topic?"
+        for i in range(num_prompts)
+    ]
+
+
+@pytest.fixture
+def sampling_params():
+    return SamplingParams(
+        temperature=0.0,
+        max_tokens=32,
+        ignore_eos=True,
+        logprobs=1,
+    )
+
+
+def _run_inference_with_config(model_name: str,
+                               test_prompts: list,
+                               sampling_params: SamplingParams,
+                               tensor_parallel_size: int = 1,
+                               data_parallel_size: int = 1,
+                               additional_config: dict = {},
+                               kv_cache_dtype: str = "auto",
+                               enable_prefix_caching: bool = False,
+                               async_scheduling: bool = False,
+                               max_model_len: int = 32,
+                               max_num_batched_tokens: int = 128,
+                               max_num_seqs: int = 16,
+                               gpu_memory_utilization: float = 0.90,
+                               trace_dir: str = None) -> list:
+
+    llm = LLM(
+        model=model_name,
+        max_model_len=max_model_len,
+        tensor_parallel_size=tensor_parallel_size,
+        data_parallel_size=data_parallel_size,
+        gpu_memory_utilization=gpu_memory_utilization,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_num_seqs=max_num_seqs,
+        enable_prefix_caching=enable_prefix_caching,
+        additional_config=additional_config,
+        kv_cache_dtype=kv_cache_dtype,
+        async_scheduling=async_scheduling,
+    )
+
+    start_time = time.time()
+    outputs = llm.generate(test_prompts, sampling_params)
+    elapsed_time = time.time() - start_time
+
+    del llm
+    time.sleep(10)
+    return outputs, elapsed_time
+
+
+def _check_performance(test_name: str, baseline_time: float, dp_time: float,
+                       num_prompts: int, tol: float):
+
+    speedup = baseline_time / dp_time if dp_time > 0 else 0
+
+    print(f"✓ {test_name} performance test results:")
+    print(f"  Number of prompts: {num_prompts}")
+    print(f"  Baseline time: {baseline_time:.2f}s")
+    print(f"  Data parallel time: {dp_time:.2f}s")
+    print(f"  Speedup: {speedup:.2f}x")
+    print(f"  Baseline throughput: {num_prompts/baseline_time:.2f} prompts/s")
+    print(f"  Data parallel throughput: {num_prompts/dp_time:.2f} prompts/s")
+
+    assert speedup >= tol, f"Data parallelism did not provide expected speedup ({tol:.2f}x): {speedup:.2f}x"
+
+
+def _check_correctness(test_name, baseline_outputs, dp_outputs):
+
+    assert len(baseline_outputs) == len(dp_outputs)
+
+    text_matches = 0
+    logprob_matches = 0
+    total_compared_logprobs = 0
+    max_logprob_diff = 0.0
+
+    for i, (baseline, dp_result) in enumerate(zip(baseline_outputs,
+                                                  dp_outputs)):
+        baseline_text = baseline.outputs[0].text.strip()
+        dp_text = dp_result.outputs[0].text.strip()
+
+        baseline_words = baseline_text.split()
+        dp_words = dp_text.split()
+        overlap_set = set(baseline_words) & set(dp_words)
+        match_percent = len(overlap_set) / len(set(baseline_words))
+        if match_percent >= 0.7:
+            text_matches += 1
+
+        # Check text output
+        if baseline_text != dp_text:
+            print(f"Text mismatch found in prompt {i}:")
+            print(f"  Baseline: {baseline_text}")
+            print(f"  Data Parallel: {dp_text}")
+            print(f"  Match percent: {match_percent:.2%}")
+
+        # Check log probabilities
+        baseline_logprobs = baseline.outputs[0].logprobs
+        dp_logprobs = dp_result.outputs[0].logprobs
+
+        if baseline_logprobs is not None and dp_logprobs is not None:
+            # Compare log probabilities for each token
+            assert len(baseline_logprobs) == len(dp_logprobs), \
+                f"Logprobs length mismatch: {len(baseline_logprobs)} vs {len(dp_logprobs)}"
+
+            for token_idx, (base_lp, dp_lp) in enumerate(
+                    zip(baseline_logprobs, dp_logprobs)):
+                # Get the top logprob value for the selected token
+                if base_lp and dp_lp:
+                    # Get the top token's logprob from each
+                    base_top_token = list(base_lp.keys())[0]
+                    dp_top_token = list(dp_lp.keys())[0]
+
+                    # Only compare logprobs if tokens match
+                    if base_top_token == dp_top_token:
+                        base_logprob_val = base_lp[base_top_token].logprob
+                        dp_logprob_val = dp_lp[dp_top_token].logprob
+
+                        # Calculate absolute difference
+                        diff = abs(base_logprob_val - dp_logprob_val)
+                        max_logprob_diff = max(max_logprob_diff, diff)
+
+                        total_compared_logprobs += 1
+                        # Count as match if difference is small
+                        if diff < 0.1:
+                            logprob_matches += 1
+                        else:
+                            print(
+                                f"  Logprob mismatch in prompt {i}, token {token_idx}: "
+                                f"Baseline logprob={base_logprob_val}, "
+                                f"Data Parallel logprob={dp_logprob_val}, "
+                                f"Diff={diff:.6e}")
+
+    print(f"✓ {test_name} correctness test results:")
+    print(f"  Text: {text_matches} matches (match percent >= 70%)")
+    print(
+        f"  Logprobs: {logprob_matches}/{total_compared_logprobs} ({logprob_matches / total_compared_logprobs:.2%}) matches (diff < 0.1)"
+    )
+    print(f"  Max logprob difference: {max_logprob_diff:.6e}")
+
+    # Allow for some variance due to potential numerical differences,
+    # but most outputs should match with greedy sampling
+    text_match_rate = text_matches / len(baseline_outputs)
+    assert text_match_rate >= 0.9, f"Text match rate {text_match_rate:.2%} is too low"
+
+    # Log probabilities should match for most matching tokens
+    if total_compared_logprobs > 0:
+        logprob_match_rate = logprob_matches / total_compared_logprobs
+        assert logprob_match_rate >= 0.9, f"Logprob match rate {logprob_match_rate:.2%} is too low"
+
+
+def test_attention_data_parallelism(
+    test_prompts: list,
+    sampling_params: SamplingParams,
+):
+    """
+    Correctness and performance test for attention DP
+    """
+
+    os.environ['MODEL_IMPL_TYPE'] = "vllm"
+    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
+
+    # Configuration for long sequences
+    max_model_len = 2048
+    max_num_batched_tokens = 4096
+    max_num_seqs = 128
+
+    # Run with attn_dp=2 tp=2
+    dp_outputs, dp_time = _run_inference_with_config(
+        model_name=model_name,
+        test_prompts=test_prompts,
+        sampling_params=sampling_params,
+        tensor_parallel_size=4,
+        async_scheduling=False,
+        max_model_len=max_model_len,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_num_seqs=max_num_seqs,
+        additional_config={
+            "sharding": {
+                "sharding_strategy": {
+                    "enable_dp_attention": 1
+                }
+            }
+        })
+
+    # Run baseline (tp=2)
+    baseline_outputs, baseline_time = _run_inference_with_config(
+        model_name=model_name,
+        test_prompts=test_prompts,
+        sampling_params=sampling_params,
+        tensor_parallel_size=2,
+        data_parallel_size=1,
+        async_scheduling=False,
+        max_model_len=max_model_len,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_num_seqs=max_num_seqs,
+    )
+
+    _check_correctness("Attention data parallelism", baseline_outputs,
+                       dp_outputs)
+
+    # Different hardware gives different performance. This test runs on v6e_8
+    _check_performance("Attention data parallelism",
+                       baseline_time,
+                       dp_time,
+                       len(test_prompts),
+                       tol=1.1)
+
+
+def test_data_parallelism(
+    sampling_params: SamplingParams,
+    test_prompts: list,
+):
+    """
+    Correctness and performance test for model DP
+    """
+    os.environ['MODEL_IMPL_TYPE'] = "flax_nnx"
+
+    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
+
+    # Configuration for long sequences
+    max_model_len = 2048
+    max_num_batched_tokens = 4096
+    max_num_seqs = 128
+
+    # Run with data parallelism (dp=2, tp=1)
+    dp_outputs, dp_time = _run_inference_with_config(
+        model_name=model_name,
+        test_prompts=test_prompts,
+        sampling_params=sampling_params,
+        tensor_parallel_size=1,
+        data_parallel_size=2,
+        async_scheduling=True,
+        max_model_len=max_model_len,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_num_seqs=max_num_seqs,
+    )
+
+    # Run baseline (tp=1)
+    baseline_outputs, baseline_time = _run_inference_with_config(
+        model_name=model_name,
+        test_prompts=test_prompts,
+        sampling_params=sampling_params,
+        tensor_parallel_size=1,
+        data_parallel_size=1,
+        async_scheduling=True,
+        max_model_len=max_model_len,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_num_seqs=max_num_seqs,
+    )
+
+    _check_correctness("Data parallelism", baseline_outputs, dp_outputs)
+
+    # Test is too small to see significant speedup, mainly for testing regression
+    _check_performance("Data parallelism",
+                       baseline_time,
+                       dp_time,
+                       len(test_prompts),
+                       tol=1.1)
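For reference, the attention-DP path exercised above is driven entirely through the additional_config dict that the test passes to LLM. Below is a minimal standalone sketch: the model name, sharding keys, and batch sizes are copied from test_attention_data_parallelism in the diff, while running it outside the test harness (and on any particular TPU topology) is an assumption, not something the package documents.

import os

from vllm import LLM, SamplingParams

# Sketch only: mirrors test_attention_data_parallelism above.
# MODEL_IMPL_TYPE and the sharding_strategy keys come from the diff;
# whether they apply unchanged to your vllm/tpu-inference versions is an assumption.
os.environ["MODEL_IMPL_TYPE"] = "vllm"

llm = LLM(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    tensor_parallel_size=4,
    max_model_len=2048,
    max_num_batched_tokens=4096,
    max_num_seqs=128,
    additional_config={
        "sharding": {
            "sharding_strategy": {
                "enable_dp_attention": 1
            }
        }
    },
)
outputs = llm.generate(
    ["Summarize the benefits of attention data parallelism."],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)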