tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,265 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
3
|
+
|
|
4
|
+
import os
|
|
5
|
+
import time
|
|
6
|
+
from dataclasses import asdict
|
|
7
|
+
|
|
8
|
+
import pytest
|
|
9
|
+
from vllm import LLM, EngineArgs, SamplingParams
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@pytest.fixture
def model_name():
    """Provide the model identifier shared by the tests in this module.

    Llama 3.1 8B Instruct is used because it supports PP on the JAX model
    implementation.
    """
    llama_checkpoint = "meta-llama/Llama-3.1-8B-Instruct"
    return llama_checkpoint
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@pytest.fixture
def test_prompts():
    """Provide ten short prompts used by the pipeline parallelism tests."""
    # Open-ended completions first, then factual questions.
    completion_prompts = [
        "Hello, my name is",
        "The capital of France is",
        "The colors of the rainbow are",
        "The future of AI is",
        "The president of the United States is",
    ]
    question_prompts = [
        "How many players are on a standard soccer team?",
        "In Greek mythology, who is the god of the sea?",
        "What is the capital of Australia?",
        "What is the largest planet in our solar system?",
        "Who developed the theory of general relativity?",
    ]
    return completion_prompts + question_prompts
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@pytest.fixture
def sampling_params():
    """Provide the sampling configuration shared by the tests in this module.

    Uses temperature 0.0, generates exactly 32 tokens per prompt (EOS is
    ignored) and requests one logprob per generated token.
    """
    settings = {
        "temperature": 0.0,
        "max_tokens": 32,
        "ignore_eos": True,
        "logprobs": 1,
    }
    return SamplingParams(**settings)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _run_inference_with_config(model_name: str,
                               test_prompts: list,
                               sampling_params: SamplingParams,
                               tensor_parallel_size: int = 1,
                               pipeline_parallel_size: int = 1,
                               additional_config: "dict | None" = None,
                               kv_cache_dtype: str = "auto",
                               enable_prefix_caching: bool = False) -> list:
    """Helper function to run inference with specified configuration.

    Args:
        model_name: Model identifier passed to ``EngineArgs``.
        test_prompts: Prompts to generate completions for.
        sampling_params: Sampling settings passed to ``LLM.generate``.
        tensor_parallel_size: Tensor parallel degree for the engine.
        pipeline_parallel_size: Pipeline parallel degree for the engine.
        additional_config: Extra engine configuration. ``None`` (the default)
            is treated as an empty dict.
        kv_cache_dtype: KV cache dtype passed to ``EngineArgs``.
        enable_prefix_caching: Whether prefix caching is enabled.

    Returns:
        The outputs returned by ``LLM.generate``, one per prompt.
    """
    # A fresh dict per call: a `{}` default in the signature would be one
    # shared object that any mutation would leak into later calls.
    if additional_config is None:
        additional_config = {}

    # Create LLM args using parser-based approach similar to offline_inference.py
    engine_args = EngineArgs(
        model=model_name,
        max_model_len=128,
        tensor_parallel_size=tensor_parallel_size,
        pipeline_parallel_size=pipeline_parallel_size,
        gpu_memory_utilization=0.95,
        max_num_batched_tokens=128,
        max_num_seqs=16,
        enable_prefix_caching=enable_prefix_caching,
        additional_config=additional_config,
        kv_cache_dtype=kv_cache_dtype,
    )

    engine_args_dict = asdict(engine_args)
    llm = LLM(**engine_args_dict)

    try:
        outputs = llm.generate(test_prompts, sampling_params)
        return outputs
    finally:
        # Drop the engine even when generation fails, so the next engine
        # created in this process can acquire the devices.
        del llm
        # Wait for TPUs to be released
        time.sleep(5)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@pytest.mark.skip(reason="PP is not fully enabled.")
def test_pipeline_parallelism_jax_model(
    model_name: str,
    test_prompts: list,
    sampling_params: SamplingParams,
):
    """
    Test that pipeline parallelism works on JAX models.

    Equivalent to:
    python examples/offline_inference.py --tensor_parallel_size=1 --pipeline_parallel_size=2
    """
    # Two pipeline stages, no tensor parallelism.
    parallel_config = {
        "tensor_parallel_size": 1,
        "pipeline_parallel_size": 2,
    }
    outputs = _run_inference_with_config(
        model_name=model_name,
        test_prompts=test_prompts,
        sampling_params=sampling_params,
        **parallel_config,
    )

    # One result is expected per prompt.
    assert len(outputs) == len(test_prompts)

    # Each result must carry at least one completion with non-blank text.
    for request_output in outputs:
        completions = request_output.outputs
        assert len(completions) > 0
        generated_text = completions[0].text.strip()
        assert len(generated_text) > 0

    print(
        f"✓ Pipeline Parallelism Jax model test passed with {len(outputs)} outputs"
    )
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
@pytest.mark.skip(reason="PP is not fully enabled.")
def test_pipeline_parallelism_vllm_model(
    model_name: str,
    test_prompts: list,
    sampling_params: SamplingParams,
):
    """
    Test that pipeline parallelism works on vLLM models.

    Equivalent to:
    MODEL_IMPL_TYPE=vllm python examples/offline_inference.py --tensor_parallel_size=1 --pipeline_parallel_size=2
    """
    # Remember the caller's setting so the override does not leak into tests
    # that run later in the same process.
    previous_impl_type = os.environ.get('MODEL_IMPL_TYPE')
    os.environ['MODEL_IMPL_TYPE'] = 'vllm'
    try:
        # Test with pipeline parallelism enabled
        outputs = _run_inference_with_config(
            model_name=model_name,
            test_prompts=test_prompts,
            sampling_params=sampling_params,
            tensor_parallel_size=1,
            pipeline_parallel_size=2,
        )
    finally:
        if previous_impl_type is None:
            os.environ.pop('MODEL_IMPL_TYPE', None)
        else:
            os.environ['MODEL_IMPL_TYPE'] = previous_impl_type

    # Verify we got outputs for all prompts
    assert len(outputs) == len(test_prompts)

    # Verify each output has generated text
    for output in outputs:
        assert len(output.outputs) > 0
        assert len(output.outputs[0].text.strip()) > 0

    print(
        f"✓ Pipeline Parallelism vLLM model test passed with {len(outputs)} outputs"
    )
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
@pytest.mark.skip(reason="PP is not fully enabled.")
def test_pipeline_parallelism_jax_model_correctness(
    model_name: str,
    test_prompts: list,
    sampling_params: SamplingParams,
):
    """
    Test that pipeline parallelism produces consistent results compared to a baseline.
    This test compares outputs from a single-device run with pipeline parallel runs
    to ensure correctness, including log probabilities.

    The test passes when at least 90% of the generated texts match exactly
    and the largest per-token logprob difference is below 1.
    """
    # Apply the environment overrides only for the duration of the two
    # inference runs, then restore the previous values so they do not leak
    # into tests that run later in the same process.
    env_overrides = {
        'SKIP_JAX_PRECOMPILE': '1',
        'VLLM_XLA_CHECK_RECOMPILATION': '0',
    }
    saved_env = {key: os.environ.get(key) for key in env_overrides}
    os.environ.update(env_overrides)

    # Use at most the first 10 prompts for correctness testing
    small_prompts = test_prompts[:10]

    try:
        # Run baseline (no PP)
        baseline_outputs = _run_inference_with_config(
            model_name=model_name,
            test_prompts=small_prompts,
            sampling_params=sampling_params,
            tensor_parallel_size=1,
            pipeline_parallel_size=1,
        )

        # Run with pipeline parallelism (two stages)
        pp_outputs = _run_inference_with_config(
            model_name=model_name,
            test_prompts=small_prompts,
            sampling_params=sampling_params,
            tensor_parallel_size=1,
            pipeline_parallel_size=2,
        )
    finally:
        for key, previous in saved_env.items():
            if previous is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = previous

    # Compare outputs - in theory they should be identical for greedy sampling
    # in reality there may be some differences, but overall the outputs should
    # be very similar.

    # an example:
    # prompt: What is the capital of Australia?
    # both answers should be acceptable.
    # The capital of Australia is Canberra. It is located in the Australian Capital Territory (ACT) and is home to many
    # Canberra is the capital of Australia. It is located in the Australian Capital Territory (ACT) and is home to
    assert len(baseline_outputs) == len(pp_outputs)
    # Guard the match-rate division below against an empty result list.
    assert len(baseline_outputs) > 0, "No outputs were generated"

    text_matches = 0
    text_mismatches = 0
    logprob_mismatches = 0
    max_logprob_diff = 0.0

    for i, (baseline, pp_result) in enumerate(zip(baseline_outputs,
                                                  pp_outputs)):
        baseline_text = baseline.outputs[0].text.strip()
        pp_text = pp_result.outputs[0].text.strip()

        # Check text output
        if baseline_text == pp_text:
            text_matches += 1
        else:
            text_mismatches += 1
            print(f"Text mismatch found in prompt {i}:")
            print(f"  Baseline: {baseline_text}")
            print(f"  Pipeline Parallel: {pp_text}")

        # Check log probabilities
        baseline_logprobs = baseline.outputs[0].logprobs
        pp_logprobs = pp_result.outputs[0].logprobs
        if baseline_logprobs is not None and pp_logprobs is not None:
            # Compare log probabilities for each token
            assert len(baseline_logprobs) == len(pp_logprobs), \
                f"Logprobs length mismatch: {len(baseline_logprobs)} vs {len(pp_logprobs)}"
            for token_idx, (base_lp, pp_lp) in enumerate(
                    zip(baseline_logprobs, pp_logprobs)):
                # Skip positions where either side has no logprob entry
                if base_lp and pp_lp:
                    # Take the first key of each per-position mapping
                    base_top_token = next(iter(base_lp))
                    pp_top_token = next(iter(pp_lp))

                    base_logprob_val = base_lp[base_top_token].logprob
                    pp_logprob_val = pp_lp[pp_top_token].logprob

                    # Calculate absolute difference
                    diff = abs(base_logprob_val - pp_logprob_val)
                    max_logprob_diff = max(max_logprob_diff, diff)

                    # Report (but do not fail on) differences above 1e-3
                    if diff > 1e-3:
                        logprob_mismatches += 1
                        print(
                            f"Logprob mismatch in prompt {i}, token {token_idx}:"
                        )
                        print(
                            f"  Baseline token: {base_top_token}, logprob: {base_logprob_val:.6f}"
                        )
                        print(
                            f"  PP token: {pp_top_token}, logprob: {pp_logprob_val:.6f}"
                        )
                        print(f"  Difference: {diff:.6f}")

    print("✓ Correctness test results:")
    print(f"  Text: {text_matches} matches, {text_mismatches} mismatches")
    print(f"  Max logprob difference: {max_logprob_diff:.6e}")
    print(f"  Significant logprob mismatches (>1e-3): {logprob_mismatches}")

    # Allow for some variance due to potential numerical differences
    # but most outputs should match with greedy sampling
    text_match_rate = text_matches / len(baseline_outputs)
    assert text_match_rate >= 0.9, f"Text match rate {text_match_rate:.2%} is too low"

    # The hard bound on logprob drift is deliberately loose (< 1); smaller
    # differences above 1e-3 are only reported above.
    assert max_logprob_diff < 1, f"Max logprob difference {max_logprob_diff} is too large"
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
# This file contains end-to-end tests for the RunAI Model Streamer loader.
|
|
16
|
+
#
|
|
17
|
+
# The RunAI Model Streamer is a high-performance model loader that serves as an
|
|
18
|
+
# alternative to the default Hugging Face loader. Instead of downloading a model
|
|
19
|
+
# to local disk, it streams the weights from object storage (like GCS) into
|
|
20
|
+
# GPU memory. This streaming process is significantly faster than the
|
|
21
|
+
# traditional disk-based loading method.
|
|
22
|
+
|
|
23
|
+
# The tests in this file verify that loading model weights using the
|
|
24
|
+
# streamer produces the same results as loading the same model using the
|
|
25
|
+
# standard Hugging Face loader. This ensures the correctness of the streamer
|
|
26
|
+
# integration.
|
|
27
|
+
|
|
28
|
+
# The tests are performed by:
|
|
29
|
+
# 1. Loading a model from Google Cloud Storage using the `runai_streamer` format.
|
|
30
|
+
# 2. Generating output with this model.
|
|
31
|
+
# 3. Loading the same model from Hugging Face using the default loader.
|
|
32
|
+
# 4. Generating output with this second model.
|
|
33
|
+
# 5. Asserting that the outputs from both models are identical.
|
|
34
|
+
|
|
35
|
+
from __future__ import annotations
|
|
36
|
+
|
|
37
|
+
import time
|
|
38
|
+
|
|
39
|
+
import pytest
|
|
40
|
+
from vllm import LLM, SamplingParams
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@pytest.fixture
def sampling_config():
    """Provide sampling settings: temperature 0, exactly 10 tokens, EOS ignored."""
    settings = {
        "temperature": 0,
        "max_tokens": 10,
        "ignore_eos": True,
    }
    return SamplingParams(**settings)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@pytest.fixture
# TODO(amacaskill): Replace with GKE owned GCS bucket.
def gcs_model_name():
    """Provide the GCS URI of the model loaded through the streamer."""
    gcs_uri = "gs://vertex-model-garden-public-us/llama3/llama3-8b-hf"
    return gcs_uri
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@pytest.fixture
def hf_model_name():
    """Provide the Hugging Face id of the model loaded by the default loader."""
    hf_repo_id = "meta-llama/Meta-Llama-3-8B"
    return hf_repo_id
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@pytest.fixture
def prompt():
    """Provide the single prompt used to compare the two loaders."""
    greeting = "Hello, my name is"
    return greeting
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def test_correctness(
    sampling_config: SamplingParams,
    gcs_model_name: str,
    hf_model_name: str,
    prompt: str,
    monkeypatch: pytest.MonkeyPatch,
):
    '''
    Compare the outputs of a model loaded from GCS via runai_model_streamer
    and a model loaded from Hugging Face. The outputs should be the same.
    These tests attempt to use tensor_parallel_size=1. The model is 16GB,
    and v6e has 32GB of HBM, so it will fit.
    '''
    # Set ENV variables so that runai_model_streamer uses anonymous GCS access.
    # monkeypatch undoes these automatically when the test finishes.
    monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "fake-project")
    monkeypatch.setenv("RUNAI_STREAMER_GCS_USE_ANONYMOUS_CREDENTIALS", "true")
    monkeypatch.setenv("CLOUD_STORAGE_EMULATOR_ENDPOINT",
                       "https://storage.googleapis.com")

    # Test with the model streamed from GCS
    gcs_llm = LLM(model=gcs_model_name,
                  load_format="runai_streamer",
                  max_model_len=128,
                  max_num_seqs=16,
                  max_num_batched_tokens=256)
    try:
        gcs_outputs = gcs_llm.generate([prompt], sampling_config)
        gcs_output_text = gcs_outputs[0].outputs[0].text
    finally:
        # Drop the engine even if generation fails, so the next engine in
        # this process can acquire the devices.
        del gcs_llm
        time.sleep(10)  # Wait for TPUs to be released

    # Test with Hugging Face model
    hf_llm = LLM(model=hf_model_name,
                 max_model_len=128,
                 max_num_seqs=16,
                 max_num_batched_tokens=256)
    try:
        hf_outputs = hf_llm.generate([prompt], sampling_config)
        hf_output_text = hf_outputs[0].outputs[0].text
    finally:
        del hf_llm
        time.sleep(10)  # Wait for TPUs to be released

    assert gcs_output_text == hf_output_text, (
        "Outputs do not match! "
        f"GCS output: {gcs_output_text}, HF output: {hf_output_text}")
|
|
@@ -0,0 +1,269 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
# This file contains end-to-end tests for sampling parameters.
|
|
16
|
+
#
|
|
17
|
+
# Sampling parameters control how the model selects tokens during generation.
|
|
18
|
+
# These tests verify that temperature, top_p, top_k, and logprobs work correctly.
|
|
19
|
+
#
|
|
20
|
+
# The tests in this file verify that:
|
|
21
|
+
# 1. Temperature=0 produces deterministic (greedy) outputs
|
|
22
|
+
# 2. Higher temperature produces more varied outputs
|
|
23
|
+
# 3. top_p (nucleus sampling) correctly constrains token selection
|
|
24
|
+
# 4. top_k correctly limits the number of candidate tokens
|
|
25
|
+
# 5. logprobs returns probability information for generated tokens
|
|
26
|
+
|
|
27
|
+
from __future__ import annotations
|
|
28
|
+
|
|
29
|
+
import pytest
|
|
30
|
+
from vllm import LLM, SamplingParams
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@pytest.fixture(scope="module")
def llm():
    """Build the LLM engine once and share it across this module's tests."""
    engine_kwargs = {
        "model": 'meta-llama/Llama-3.2-1B-Instruct',
        "max_model_len": 1024,
        "max_num_seqs": 4,
        "enable_prefix_caching": False,
    }
    return LLM(**engine_kwargs)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class TestTemperature:
    """Checks on how the temperature setting affects generation."""

    def test_temperature_zero_is_deterministic(self, llm: LLM):
        """Two temperature=0 runs of the same prompt must yield the same text."""
        greedy_params = SamplingParams(temperature=0, max_tokens=10)
        question = "What is 2 + 2? Answer with just the number:"

        first_run = llm.generate([question], greedy_params)
        second_run = llm.generate([question], greedy_params)

        first_text = first_run[0].outputs[0].text
        second_text = second_run[0].outputs[0].text
        assert first_text == second_text

    def test_high_temperature_produces_variation(self, llm: LLM):
        """Ten temperature=2 runs of one prompt must not all yield the same text."""
        hot_params = SamplingParams(temperature=2,
                                    max_tokens=10,
                                    top_k=4096)
        request = "Write a random word:"
        num_runs = 10

        # Generate repeatedly and keep only the distinct completions.
        unique_outputs = {
            llm.generate([request], hot_params)[0].outputs[0].text
            for _ in range(num_runs)
        }

        # At least two different completions are expected.
        assert len(unique_outputs) > 1, (
            "High temperature should produce varied outputs")
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class TestTopP:
    """End-to-end checks for the ``top_p`` (nucleus sampling) parameter."""

    def test_top_p_restricts_sampling(self, llm: LLM):
        """Generate with top_p=1.0 and top_p=0.1 and check outputs exist.

        A tighter nucleus would normally give less variety, but that is
        not guaranteed, so the test only requires that each setting
        yields at least one output.
        """
        prompt = "Name a color:"

        # top_p=1.0 leaves the whole distribution available.
        wide_params = SamplingParams(temperature=0.8,
                                     top_p=1.0,
                                     max_tokens=5)
        # top_p=0.1 keeps only the most probable tokens.
        narrow_params = SamplingParams(temperature=0.8,
                                       top_p=0.1,
                                       max_tokens=5)

        # Ten runs with the full nucleus, then ten with the restricted one.
        wide_texts = {
            llm.generate([prompt], wide_params)[0].outputs[0].text
            for _ in range(10)
        }
        narrow_texts = {
            llm.generate([prompt], narrow_params)[0].outputs[0].text
            for _ in range(10)
        }

        assert len(narrow_texts) >= 1, "Should produce at least one output"
        assert len(wide_texts) >= 1, "Should produce at least one output"

    def test_top_p_with_temperature_zero(self, llm: LLM):
        """With temperature=0, different top_p values give the same text."""
        prompt = "The capital of France is"

        low_p_params = SamplingParams(temperature=0,
                                      top_p=0.1,
                                      max_tokens=10)
        high_p_params = SamplingParams(temperature=0,
                                       top_p=0.9,
                                       max_tokens=10)

        low_p_result = llm.generate([prompt], low_p_params)
        high_p_result = llm.generate([prompt], high_p_params)

        # Greedy decoding is expected to ignore the nucleus size.
        low_p_text = low_p_result[0].outputs[0].text
        high_p_text = high_p_result[0].outputs[0].text
        assert low_p_text == high_p_text
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class TestTopK:
    """End-to-end checks for the ``top_k`` sampling parameter."""

    def test_top_k_restricts_sampling(self, llm: LLM):
        """top_k=1 must repeat exactly; top_k=-1 must still produce output."""
        prompt = "Pick a number between 1 and 10:"

        # A single candidate token per step behaves like greedy decoding.
        single_candidate = SamplingParams(temperature=1.0,
                                          top_k=1,
                                          max_tokens=5)
        # top_k=-1 places no limit on the candidate set.
        unrestricted = SamplingParams(temperature=1.0,
                                      top_k=-1,
                                      max_tokens=5)

        # Two top_k=1 runs must agree on the generated text.
        first_run = llm.generate([prompt], single_candidate)
        second_run = llm.generate([prompt], single_candidate)
        first_text = first_run[0].outputs[0].text
        second_text = second_run[0].outputs[0].text
        assert first_text == second_text

        # Unrestricted sampling may vary from run to run; collect the
        # distinct texts over ten runs.
        unrestricted_texts = {
            llm.generate([prompt], unrestricted)[0].outputs[0].text
            for _ in range(10)
        }

        # Only the presence of an output is required here.
        assert len(unrestricted_texts) >= 1

    def test_top_k_with_temperature_zero(self, llm: LLM):
        """With temperature=0, different top_k values give the same text."""
        prompt = "The largest planet is"

        small_k_params = SamplingParams(temperature=0,
                                        top_k=5,
                                        max_tokens=10)
        large_k_params = SamplingParams(temperature=0,
                                        top_k=50,
                                        max_tokens=10)

        small_k_result = llm.generate([prompt], small_k_params)
        large_k_result = llm.generate([prompt], large_k_params)

        # Greedy decoding is expected to ignore the candidate limit.
        small_k_text = small_k_result[0].outputs[0].text
        large_k_text = large_k_result[0].outputs[0].text
        assert small_k_text == large_k_text
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
class TestLogprobs:
    """End-to-end checks for the ``logprobs`` sampling parameter."""

    def test_logprobs_returns_probabilities(self, llm: LLM):
        """Requesting logprobs=5 yields per-token entries of size <= 5."""
        prompt = "Hello"
        params = SamplingParams(temperature=0, max_tokens=5, logprobs=5)

        completion = llm.generate([prompt], params)[0].outputs[0]

        # The logprobs container must exist and be non-empty.
        assert completion.logprobs is not None, "logprobs should be returned"
        assert len(completion.logprobs) > 0, "logprobs should contain entries"

        # Every generated position carries at most the 5 requested entries.
        for per_token in completion.logprobs:
            assert per_token is not None
            assert len(per_token) <= 5

    def test_logprobs_none_returns_no_probabilities(self, llm: LLM):
        """With logprobs=None the completion carries no logprobs."""
        prompt = "Hello"
        params = SamplingParams(temperature=0, max_tokens=5, logprobs=None)

        completion = llm.generate([prompt], params)[0].outputs[0]

        assert completion.logprobs is None, "logprobs should be None when not requested"

    def test_logprobs_values_are_valid(self, llm: LLM):
        """Every returned log probability must be less than or equal to 0."""
        prompt = "The sky is"
        params = SamplingParams(temperature=0, max_tokens=3, logprobs=3)

        completion = llm.generate([prompt], params)[0].outputs[0]

        assert completion.logprobs is not None
        for per_token in completion.logprobs:
            # Only the logprob objects matter; the token-id keys are unused.
            for logprob_obj in per_token.values():
                assert logprob_obj.logprob <= 0, (
                    f"Log probability should be <= 0, got {logprob_obj.logprob}"
                )
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
class TestCombinedParameters:
    """End-to-end checks for sampling parameters used in combination."""

    def test_top_p_and_top_k_combined(self, llm: LLM):
        """Setting top_p and top_k together still produces non-empty text."""
        prompt = "List a fruit:"
        params = SamplingParams(temperature=0.7,
                                top_p=0.9,
                                top_k=50,
                                max_tokens=10)

        generated_text = llm.generate([prompt], params)[0].outputs[0].text
        assert len(generated_text) > 0

    def test_all_params_with_logprobs(self, llm: LLM):
        """temperature, top_p, top_k and logprobs work in one request."""
        prompt = "Complete this sentence: The weather today is"
        params = SamplingParams(temperature=0.5,
                                top_p=0.95,
                                top_k=40,
                                max_tokens=10,
                                logprobs=3)

        completion = llm.generate([prompt], params)[0].outputs[0]

        # Text must have been generated.
        assert len(completion.text) > 0

        # Logprobs must be present and non-empty.
        assert completion.logprobs is not None
        assert len(completion.logprobs) > 0
|