tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.2rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of tpu-inference might be problematic; see the registry listing for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +89 -26
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +46 -17
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +44 -17
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/METADATA +7 -9
- tpu_inference-0.13.2rc3.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/top_level.txt +0 -0
tests/e2e/test_runai_model_streamer_loader.py
@@ -0,0 +1,104 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This file contains end-to-end tests for the RunAI Model Streamer loader.
+#
+# The RunAI Model Streamer is a high-performance model loader that serves as an
+# alternative to the default Hugging Face loader. Instead of downloading a model
+# to local disk, it streams the weights from object storage (like GCS) into
+# GPU memory. This streaming process is significantly faster than the
+# traditional disk-based loading method.
+
+# The tests in this file verify that loading model weights using the
+# streamer produces the same results as loading the same model using the
+# standard Hugging Face loader. This ensures the correctness of the streamer
+# integration.
+
+# The tests are performed by:
+# 1. Loading a model from Google Cloud Storage using the `runai_streamer` format.
+# 2. Generating output with this model.
+# 3. Loading the same model from Hugging Face using the default loader.
+# 4. Generating output with this second model.
+# 5. Asserting that the outputs from both models are identical.
+
+from __future__ import annotations
+
+import time
+
+import pytest
+from vllm import LLM, SamplingParams
+
+
+@pytest.fixture
+def sampling_config():
+    return SamplingParams(temperature=0, max_tokens=10, ignore_eos=True)
+
+
+@pytest.fixture
+# TODO(amacaskill): Replace with GKE owned GCS bucket.
+def gcs_model_name():
+    return "gs://vertex-model-garden-public-us/llama3/llama3-8b-hf"
+
+
+@pytest.fixture
+def hf_model_name():
+    return "meta-llama/Meta-Llama-3-8B"
+
+
+@pytest.fixture
+def prompt():
+    return "Hello, my name is"
+
+
+def test_correctness(
+    sampling_config: SamplingParams,
+    gcs_model_name: str,
+    hf_model_name: str,
+    prompt: str,
+    monkeypatch: pytest.MonkeyPatch,
+):
+    '''
+    Compare the outputs of a model loaded from GCS via runai_model_streamer
+    and a model loaded from Hugging Face. The outputs should be the same.
+    These tests attempt to use tensor_parallel_size=1. The model is 16GB,
+    # and v6e has 32GB of HBM, so it will fit.
+    '''
+    # Set ENV variables so that runai_model_streamer uses anonymous GCS access.
+    monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "fake-project")
+    monkeypatch.setenv("RUNAI_STREAMER_GCS_USE_ANONYMOUS_CREDENTIALS", "true")
+    monkeypatch.setenv("CLOUD_STORAGE_EMULATOR_ENDPOINT",
+                       "https://storage.googleapis.com")
+    gcs_llm = LLM(model=gcs_model_name,
+                  load_format="runai_streamer",
+                  max_model_len=128,
+                  max_num_seqs=16,
+                  max_num_batched_tokens=256)
+    gcs_outputs = gcs_llm.generate([prompt], sampling_config)
+    gcs_output_text = gcs_outputs[0].outputs[0].text
+    del gcs_llm
+    time.sleep(10)  # Wait for TPUs to be released
+
+    # Test with Hugging Face model
+    hf_llm = LLM(model=hf_model_name,
+                 max_model_len=128,
+                 max_num_seqs=16,
+                 max_num_batched_tokens=256)
+    hf_outputs = hf_llm.generate([prompt], sampling_config)
+    hf_output_text = hf_outputs[0].outputs[0].text
+    del hf_llm
+    time.sleep(10)  # Wait for TPUs to be released
+
+    assert gcs_output_text == hf_output_text, (
+        f"Outputs do not match! "
+        f"GCS output: {gcs_output_text}, HF output: {hf_output_text}")
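The new loader test above follows the five-step recipe from its header comment: load the same checkpoint once through the `runai_streamer` format and once through the default Hugging Face loader, generate greedily with each, and diff the text. A hypothetical helper (not part of the test suite) sketching that pattern with the same vLLM `LLM`/`SamplingParams` API used in the test:

# Hypothetical helper, not part of tpu-inference: load a model with a given
# load_format, generate greedily for one prompt, and return the text so the
# outputs of two loaders can be compared.
import time

from vllm import LLM, SamplingParams


def generate_with_loader(model: str, prompt: str, load_format: str | None = None) -> str:
    kwargs = {"load_format": load_format} if load_format is not None else {}
    llm = LLM(model=model, max_model_len=128, max_num_seqs=16, **kwargs)
    params = SamplingParams(temperature=0, max_tokens=10, ignore_eos=True)
    text = llm.generate([prompt], params)[0].outputs[0].text
    del llm
    time.sleep(10)  # as in the test above, give the TPUs time to be released
    return text


# Usage mirroring the test: both loaders should produce identical greedy output.
# gcs_text = generate_with_loader(
#     "gs://vertex-model-garden-public-us/llama3/llama3-8b-hf",
#     "Hello, my name is", load_format="runai_streamer")
# hf_text = generate_with_loader("meta-llama/Meta-Llama-3-8B", "Hello, my name is")
# assert gcs_text == hf_text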
tests/e2e/test_sampling_params.py
@@ -0,0 +1,269 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This file contains end-to-end tests for sampling parameters.
+#
+# Sampling parameters control how the model selects tokens during generation.
+# These tests verify that temperature, top_p, top_k, and logprobs work correctly.
+#
+# The tests in this file verify that:
+# 1. Temperature=0 produces deterministic (greedy) outputs
+# 2. Higher temperature produces more varied outputs
+# 3. top_p (nucleus sampling) correctly constrains token selection
+# 4. top_k correctly limits the number of candidate tokens
+# 5. logprobs returns probability information for generated tokens
+
+from __future__ import annotations
+
+import pytest
+from vllm import LLM, SamplingParams
+
+
+@pytest.fixture(scope="module")
+def llm():
+    """Create a shared LLM instance for all tests in this module."""
+    return LLM(
+        model='meta-llama/Llama-3.2-1B-Instruct',
+        max_model_len=1024,
+        max_num_seqs=4,
+        enable_prefix_caching=False,
+    )
+
+
+class TestTemperature:
+    """Tests for temperature sampling parameter."""
+
+    def test_temperature_zero_is_deterministic(self, llm: LLM):
+        """Temperature=0 should produce identical outputs across multiple runs."""
+        prompt = "What is 2 + 2? Answer with just the number:"
+        sampling_params = SamplingParams(temperature=0, max_tokens=10)
+
+        outputs1 = llm.generate([prompt], sampling_params)
+        outputs2 = llm.generate([prompt], sampling_params)
+
+        assert outputs1[0].outputs[0].text == outputs2[0].outputs[0].text
+
+    def test_high_temperature_produces_variation(self, llm: LLM):
+        """High temperature should produce varied outputs across multiple runs."""
+        prompt = "Write a random word:"
+        sampling_params = SamplingParams(temperature=2,
+                                         max_tokens=10,
+                                         top_k=4096)
+
+        # Run multiple times and collect unique outputs
+        unique_outputs = set()
+        num_runs = 10
+        for _ in range(num_runs):
+            outputs = llm.generate([prompt], sampling_params)
+            unique_outputs.add(outputs[0].outputs[0].text)
+
+        # With high temperature, we expect some variation
+        assert len(unique_outputs) > 1, (
+            "High temperature should produce varied outputs")
+
+
+class TestTopP:
+    """Tests for top_p (nucleus sampling) parameter."""
+
+    def test_top_p_restricts_sampling(self, llm: LLM):
+        """top_p=1.0 vs lower values should affect output diversity."""
+        prompt = "Name a color:"
+
+        # With top_p=1.0 (consider all tokens)
+        sampling_params_full = SamplingParams(temperature=0.8,
+                                              top_p=1.0,
+                                              max_tokens=5)
+
+        # With top_p=0.1 (very restrictive, only top tokens)
+        sampling_params_restricted = SamplingParams(temperature=0.8,
+                                                    top_p=0.1,
+                                                    max_tokens=5)
+
+        # Collect outputs with full nucleus
+        full_outputs = set()
+        for _ in range(10):
+            outputs = llm.generate([prompt], sampling_params_full)
+            full_outputs.add(outputs[0].outputs[0].text)
+
+        # Collect outputs with restricted nucleus
+        restricted_outputs = set()
+        for _ in range(10):
+            outputs = llm.generate([prompt], sampling_params_restricted)
+            restricted_outputs.add(outputs[0].outputs[0].text)
+
+        # Restricted top_p should generally produce less variety
+        # (though this isn't guaranteed, it's a reasonable expectation)
+        assert len(
+            restricted_outputs) >= 1, "Should produce at least one output"
+        assert len(full_outputs) >= 1, "Should produce at least one output"
+
+    def test_top_p_with_temperature_zero(self, llm: LLM):
+        """top_p should have no effect when temperature=0 (greedy)."""
+        prompt = "The capital of France is"
+
+        sampling_params_1 = SamplingParams(temperature=0,
+                                           top_p=0.1,
+                                           max_tokens=10)
+        sampling_params_2 = SamplingParams(temperature=0,
+                                           top_p=0.9,
+                                           max_tokens=10)
+
+        outputs1 = llm.generate([prompt], sampling_params_1)
+        outputs2 = llm.generate([prompt], sampling_params_2)
+
+        # Both should produce identical outputs since temperature=0
+        assert outputs1[0].outputs[0].text == outputs2[0].outputs[0].text
+
+
+class TestTopK:
+    """Tests for top_k sampling parameter."""
+
+    def test_top_k_restricts_sampling(self, llm: LLM):
+        """top_k should limit the candidate tokens for sampling."""
+        prompt = "Pick a number between 1 and 10:"
+
+        # top_k=1 is equivalent to greedy (always pick the most likely)
+        sampling_params_k1 = SamplingParams(temperature=1.0,
+                                            top_k=1,
+                                            max_tokens=5)
+
+        # top_k=-1 considers all tokens
+        sampling_params_all = SamplingParams(temperature=1.0,
+                                             top_k=-1,
+                                             max_tokens=5)
+
+        # With top_k=1, outputs should be deterministic
+        outputs_k1_run1 = llm.generate([prompt], sampling_params_k1)
+        outputs_k1_run2 = llm.generate([prompt], sampling_params_k1)
+        assert outputs_k1_run1[0].outputs[0].text == outputs_k1_run2[
+            0].outputs[0].text
+
+        # With top_k=-1 and temperature=1.0, we may see variation
+        all_outputs = set()
+        for _ in range(10):
+            outputs = llm.generate([prompt], sampling_params_all)
+            all_outputs.add(outputs[0].outputs[0].text)
+
+        # Should produce at least one valid output
+        assert len(all_outputs) >= 1
+
+    def test_top_k_with_temperature_zero(self, llm: LLM):
+        """top_k should have no effect when temperature=0 (greedy)."""
+        prompt = "The largest planet is"
+
+        sampling_params_k5 = SamplingParams(temperature=0,
+                                            top_k=5,
+                                            max_tokens=10)
+        sampling_params_k50 = SamplingParams(temperature=0,
+                                             top_k=50,
+                                             max_tokens=10)
+
+        outputs1 = llm.generate([prompt], sampling_params_k5)
+        outputs2 = llm.generate([prompt], sampling_params_k50)
+
+        # Both should produce identical outputs since temperature=0
+        assert outputs1[0].outputs[0].text == outputs2[0].outputs[0].text
+
+
+class TestLogprobs:
+    """Tests for logprobs parameter."""
+
+    def test_logprobs_returns_probabilities(self, llm: LLM):
+        """logprobs parameter should return log probabilities for tokens."""
+        prompt = "Hello"
+        sampling_params = SamplingParams(temperature=0,
+                                         max_tokens=5,
+                                         logprobs=5)
+
+        outputs = llm.generate([prompt], sampling_params)
+        output = outputs[0].outputs[0]
+
+        # Check that logprobs are returned
+        assert output.logprobs is not None, "logprobs should be returned"
+        assert len(output.logprobs) > 0, "logprobs should contain entries"
+
+        # Each token should have logprob information
+        for token_logprobs in output.logprobs:
+            assert token_logprobs is not None
+            # Should have up to 5 top logprobs as requested
+            assert len(token_logprobs) <= 5
+
+    def test_logprobs_none_returns_no_probabilities(self, llm: LLM):
+        """When logprobs=None, no log probabilities should be returned."""
+        prompt = "Hello"
+        sampling_params = SamplingParams(temperature=0,
+                                         max_tokens=5,
+                                         logprobs=None)
+
+        outputs = llm.generate([prompt], sampling_params)
+        output = outputs[0].outputs[0]
+
+        # logprobs should be None when not requested
+        assert output.logprobs is None, "logprobs should be None when not requested"
+
+    def test_logprobs_values_are_valid(self, llm: LLM):
+        """Log probabilities should be valid (negative or zero)."""
+        prompt = "The sky is"
+        sampling_params = SamplingParams(temperature=0,
+                                         max_tokens=3,
+                                         logprobs=3)
+
+        outputs = llm.generate([prompt], sampling_params)
+        output = outputs[0].outputs[0]
+
+        assert output.logprobs is not None
+        for token_logprobs in output.logprobs:
+            for token_id, logprob_obj in token_logprobs.items():
+                # Log probabilities should be <= 0
+                assert logprob_obj.logprob <= 0, (
+                    f"Log probability should be <= 0, got {logprob_obj.logprob}"
+                )
+
+
+class TestCombinedParameters:
+    """Tests for combinations of sampling parameters."""
+
+    def test_top_p_and_top_k_combined(self, llm: LLM):
+        """top_p and top_k can be used together."""
+        prompt = "List a fruit:"
+        sampling_params = SamplingParams(
+            temperature=0.7,
+            top_p=0.9,
+            top_k=50,
+            max_tokens=10,
+        )
+
+        outputs = llm.generate([prompt], sampling_params)
+        assert len(outputs[0].outputs[0].text) > 0
+
+    def test_all_params_with_logprobs(self, llm: LLM):
+        """All sampling parameters should work together with logprobs."""
+        prompt = "Complete this sentence: The weather today is"
+        sampling_params = SamplingParams(
+            temperature=0.5,
+            top_p=0.95,
+            top_k=40,
+            max_tokens=10,
+            logprobs=3,
+        )
+
+        outputs = llm.generate([prompt], sampling_params)
+        output = outputs[0].outputs[0]
+
+        # Should have generated text
+        assert len(output.text) > 0
+
+        # Should have logprobs
+        assert output.logprobs is not None
+        assert len(output.logprobs) > 0
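The sampling-parameter tests above check behavior end to end but do not show the math they rely on. Below is a minimal sketch, not the vLLM or tpu-inference implementation, of how temperature, top_k, and top_p transform a logit vector before a token is drawn; it also illustrates why the tests expect top_p and top_k to be no-ops when temperature=0 (greedy argmax).

# Minimal sketch of temperature / top_k / top_p sampling over a logit vector.
# This is illustrative only; the real engine applies these on-device in batches.
import numpy as np


def sample_token(logits: np.ndarray, temperature: float, top_k: int,
                 top_p: float, rng: np.random.Generator) -> int:
    if temperature == 0:
        # Greedy decoding: top_k and top_p never change the argmax.
        return int(np.argmax(logits))
    logits = logits / temperature
    if top_k > 0:
        # Keep only the top_k largest logits (top_k=-1 means "keep all").
        kth = np.sort(logits)[-min(top_k, logits.size)]
        logits = np.where(logits < kth, -np.inf, logits)
    probs = np.exp(logits - np.max(logits))
    probs /= probs.sum()
    if top_p < 1.0:
        # Nucleus sampling: keep the smallest set of tokens whose cumulative
        # probability stays within top_p (always keeping the most likely one).
        order = np.argsort(-probs)
        keep = np.cumsum(probs[order]) <= top_p
        keep[0] = True
        mask = np.zeros_like(probs, dtype=bool)
        mask[order[keep]] = True
        probs = np.where(mask, probs, 0.0)
        probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs))


# With temperature=0 the result is the argmax regardless of top_k / top_p,
# which is exactly what test_top_p_with_temperature_zero and
# test_top_k_with_temperature_zero assert end to end.
rng = np.random.default_rng(0)
logits = np.array([1.0, 3.0, 2.0])
assert sample_token(logits, 0.0, top_k=1, top_p=0.1, rng=rng) == 1
assert sample_token(logits, 0.0, top_k=-1, top_p=1.0, rng=rng) == 1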
tests/e2e/test_speculative_decoding.py
@@ -0,0 +1,311 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import os
+import random
+import string
+import time
+
+import pytest
+from vllm import LLM, SamplingParams
+
+
+# TODO (Qiliang Cui): remove this when XLA fixes the recursive jit call issue.
+def _is_v7x():
+    # jax.devices() will hang so use IS_FOR_V7X to indicate the version.
+    return os.environ.get("IS_FOR_V7X", "false") == "true"
+
+
+def _get_tensor_parallel_size():
+    # Work around an XLA issue.
+    if _is_v7x():
+        return 2
+    return 1
+
+
+def get_ngram_test_prompts():
+    num_prompts = 100
+    prompts = []
+
+    for _ in range(num_prompts):
+        w = random.choice(list(string.ascii_lowercase))
+        prompts.append(
+            f"Keep repeating: {w} {w} {w} {w} {w} {w} {w} {w} {w} {w}")
+
+    return prompts
+
+
+def get_eagle3_test_prompts():
+    num_prompts = 100
+    prompts = []
+
+    for _ in range(num_prompts):
+        prompts.append(
+            "Predict the continuation of this sequence: 1 2 3 4 5 6 7 8")
+
+    return prompts
+
+
+def get_test_prompts(speculative_config: dict):
+    if speculative_config['method'] == 'ngram':
+        return get_ngram_test_prompts()
+    elif speculative_config['method'] == 'eagle3':
+        return get_eagle3_test_prompts()
+    else:
+        raise NotImplementedError(
+            f"{speculative_config['method']} is not supported yet.")
+
+
+@pytest.fixture
+def sampling_config():
+    return SamplingParams(temperature=0,
+                          max_tokens=32,
+                          ignore_eos=True,
+                          repetition_penalty=1,
+                          frequency_penalty=0,
+                          presence_penalty=0,
+                          min_p=0,
+                          logprobs=None)
+
+
+@pytest.fixture
+def model_name():
+    return "Qwen/Qwen2.5-0.5B-Instruct"
+
+
+# TODO(pooyam): run vLLM engine with InProcClient (`VLLM_ENABLE_V1_MULTIPROCESSING = 0`) mode to avoid TPU contention among processes.
+def _test_correctness_helper(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+    model_name: str,
+    speculative_config: dict,
+):
+    '''
+    Helper function to test ngram correctness.
+    Compare the outputs of a original LLM and a speculative LLM
+    should be the same when using ngram speculative decoding.
+    '''
+    with monkeypatch.context():
+        test_prompts = get_test_prompts(speculative_config)
+
+        ref_llm = LLM(model=model_name,
+                      max_model_len=1024,
+                      max_num_seqs=4,
+                      tensor_parallel_size=_get_tensor_parallel_size())
+        ref_outputs = ref_llm.generate(test_prompts, sampling_config)
+
+        del ref_llm
+
+        # Waiting for TPUs to be released.
+        time.sleep(10)
+
+        spec_llm = LLM(model=model_name,
+                       speculative_config=speculative_config,
+                       max_model_len=1024,
+                       max_num_seqs=4,
+                       tensor_parallel_size=_get_tensor_parallel_size())
+        spec_outputs = spec_llm.generate(test_prompts, sampling_config)
+
+        matches = 0
+        misses = 0
+        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+            if ref_output.outputs[0].text == spec_output.outputs[0].text:
+                matches += 1
+            else:
+                misses += 1
+                print(f"ref_output: {ref_output.outputs[0].text}")
+                print(f"spec_output: {spec_output.outputs[0].text}")
+
+        assert misses == 0
+        del spec_llm
+
+        # Waiting for TPUs to be released.
+        time.sleep(10)
+
+
+def test_ngram_correctness_greedy(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+    model_name: str,
+):
+    '''
+    Compare the outputs of a original LLM and a speculative LLM
+    should be the same when using ngram speculative decoding with greedy sampling.
+    '''
+    _test_correctness_helper(
+        monkeypatch, sampling_config, model_name, {
+            "method": "ngram",
+            "prompt_lookup_max": 5,
+            "prompt_lookup_min": 3,
+            "num_speculative_tokens": 3,
+        })
+
+
+def test_ngram_correctness_random(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+    model_name: str,
+):
+    '''
+    Compare the outputs of a original LLM and a speculative LLM
+    should be the same when using ngram speculative decoding with random sampling.
+    '''
+    # Modify sampling config for random sampling
+    sampling_config.temperature = 0.01
+    sampling_config.top_p = 0.9
+    sampling_config.top_k = 5
+
+    _test_correctness_helper(
+        monkeypatch, sampling_config, model_name, {
+            "method": "ngram",
+            "prompt_lookup_max": 5,
+            "prompt_lookup_min": 3,
+            "num_speculative_tokens": 3,
+        })
+
+
+def _test_performance_helper(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+    speculative_config: dict,
+    min_speedup: float,
+):
+    '''
+    Helper function to test speculative decoding performance.
+    Compares timing between reference LLM and speculative LLM using Llama 3 8B.
+    '''
+    model_name = "meta-llama/Llama-3.1-8B-Instruct"
+
+    with monkeypatch.context():
+        # Use a smaller set of prompts for performance testing
+        test_prompts = get_test_prompts(speculative_config)
+
+        # Test reference LLM timing
+        ref_llm = LLM(model=model_name,
+                      max_model_len=1024,
+                      max_num_seqs=1,
+                      enable_prefix_caching=False,
+                      tensor_parallel_size=_get_tensor_parallel_size())
+
+        start_time = time.time()
+        _ = ref_llm.generate(test_prompts, sampling_config)
+        ref_time = time.time() - start_time
+
+        del ref_llm
+
+        # Waiting for TPUs to be released
+        time.sleep(10)
+
+        # Test speculative LLM timing with max_num_seqs=1
+        spec_llm = LLM(model=model_name,
+                       speculative_config=speculative_config,
+                       max_model_len=1024,
+                       max_num_seqs=1,
+                       tensor_parallel_size=_get_tensor_parallel_size(),
+                       enable_prefix_caching=False)
+
+        start_time = time.time()
+        _ = spec_llm.generate(test_prompts, sampling_config)
+        spec_time = time.time() - start_time
+
+        del spec_llm
+        # Waiting for TPUs to be released
+        time.sleep(10)
+
+        speedup = ref_time / spec_time
+        print(f"Reference LLM time: {ref_time:.2f}s")
+        print(f"Speculative LLM time: {spec_time:.2f}s")
+        print(f"Speedup: {speedup:.2f}x")
+
+        # TODO(pooyam): Make this tighter once we have better performance.
+        assert speedup >= min_speedup, f"Expected at least {min_speedup}x speedup for {speculative_config['method']}, got {speedup:.2f}x"
+
+
+def test_ngram_performance_greedy(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+):
+    '''
+    Test that speculative decoding provides significant performance improvement.
+    Compares timing between reference LLM and speculative LLM using Llama 3 8B.
+    Expects spec_llm to be at least 3.x faster than ref_llm.
+    '''
+    _test_performance_helper(
+        monkeypatch, sampling_config, {
+            "method": "ngram",
+            "prompt_lookup_max": 2,
+            "prompt_lookup_min": 2,
+            "num_speculative_tokens": 4,
+        }, 1.2 if _is_v7x() else 3.0)
+
+
+def test_ngram_performance_random(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+):
+    '''
+    Test that speculative decoding provides significant performance improvement.
+    Compares timing between reference LLM and speculative LLM using Llama 3 8B.
+    Expects spec_llm to be at least 3.x faster than ref_llm.
+    '''
+    sampling_config.temperature = 0.01
+    sampling_config.top_p = 0.9
+    sampling_config.top_k = 5
+
+    _test_performance_helper(
+        monkeypatch, sampling_config, {
+            "method": "ngram",
+            "prompt_lookup_max": 2,
+            "prompt_lookup_min": 2,
+            "num_speculative_tokens": 4,
+        }, 1.5 if _is_v7x() else 3.0)
+
+
+def test_eagle3_correctness(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+):
+    '''
+    Compare the outputs of a original LLM and a speculative LLM
+    should be the same when using eagle-3 speculative decoding.
+    '''
+    model_name = 'meta-llama/Meta-Llama-3-8B-Instruct'
+
+    _test_correctness_helper(
+        monkeypatch, sampling_config, model_name, {
+            'model': "unkmaster/EAGLE3-LLaMA3.1-Instruct-8B",
+            "num_speculative_tokens": 3,
+            "method": "eagle3",
+            "draft_tensor_parallel_size": 1
+        })
+
+
+def test_eagle3_performance(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+):
+    '''
+    Test that speculative decoding provides significant performance improvement.
+    Compares timing between reference LLM and speculative LLM using Llama 3 8B.
+    Expects spec_llm to be at least 1.8 faster than ref_llm.
+    '''
+    _test_performance_helper(
+        monkeypatch, sampling_config, {
+            "method": "eagle3",
+            "model": "unkmaster/EAGLE3-LLaMA3.1-Instruct-8B",
+            "num_speculative_tokens": 2,
+            "draft_tensor_parallel_size": 1
+        }, 1.2 if _is_v7x() else 1.8)
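The ngram correctness tests above rely on the fact that speculative decoding only proposes draft tokens; the target model still verifies them, so greedy output must match the non-speculative run. Below is a minimal sketch, not vLLM's implementation, of the prompt-lookup ("ngram") proposal step that the `prompt_lookup_min`, `prompt_lookup_max`, and `num_speculative_tokens` keys in the configs above control.

# Minimal sketch of ngram / prompt-lookup drafting: find the most recent
# n-gram (n between prompt_lookup_min and prompt_lookup_max) earlier in the
# token sequence and propose the tokens that followed it, up to
# num_speculative_tokens. The verified target model decides which proposals
# to accept, so correctness does not depend on the draft quality.
def propose_draft(tokens: list[int], prompt_lookup_min: int,
                  prompt_lookup_max: int,
                  num_speculative_tokens: int) -> list[int]:
    for n in range(prompt_lookup_max, prompt_lookup_min - 1, -1):
        if len(tokens) < n:
            continue
        suffix = tokens[-n:]
        # Scan right-to-left for an earlier occurrence of the suffix n-gram.
        for start in range(len(tokens) - n - 1, -1, -1):
            if tokens[start:start + n] == suffix:
                cont = tokens[start + n:start + n + num_speculative_tokens]
                if cont:
                    return cont
    return []  # no match: fall back to normal (non-speculative) decoding


# Repetitive sequences give long matches, which is why the ngram tests above
# use prompts like "Keep repeating: a a a a ...".
assert propose_draft([5, 7, 9, 5, 7], 2, 5, 3) == [9, 5, 7]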