tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tests/e2e/test_local_disagg.py
@@ -0,0 +1,257 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import time
from dataclasses import asdict
from unittest.mock import patch

import pytest
import vllm.envs as vllm_envs
from vllm import LLM, EngineArgs, SamplingParams

from tpu_inference.core.core_tpu import DisaggEngineCore, DisaggEngineCoreProc


@pytest.fixture
def test_prompts():
    """Simple test prompts for disaggregated serving testing."""
    return [
        "Hello, my name is",
        "The capital of France is",
        "The colors of the rainbow are",
        "The future of AI is",
        "The president of the United States is",
        "How many players are on a standard soccer team on the field at one time?",
        "In Greek mythology, who is the god of the sea?",
        "In what year did the Titanic sink?",
        "In which museum is the Mona Lisa displayed?",
        "Mount Everest is located in which mountain range?",
        "What ancient empire was ruled by Julius Caesar?",
        "What are the four fundamental forces of nature?",
        'What does "CPU" stand for?',
        'What does "HTML" stand for?',
        "What is the capital of Australia?",
        "What is the chemical symbol for gold?",
        "What is the currency of Switzerland?",
        "What is the distance from the Earth to the Sun called?",
        "What is the freezing point of water in Celsius?",
        "What is the hardest known natural substance on Earth?",
        "What is the largest planet in our solar system?",
        "What is the longest river in the world?",
        "What is the main function of the kidneys in the human body?",
        "What is the main ingredient in guacamole?",
        "What is the most spoken language in the world by number of native speakers?",
        "What is the process by which plants use sunlight to create food?",
        "Which country is known as the Land of the Rising Sun?",
        "Who developed the theory of general relativity?",
        'Who directed the original "Star Wars" trilogy?',
        "Who is credited with inventing the telephone?",
        "Who painted the ceiling of the Sistine Chapel?",
        "Who was the first female Prime Minister of the United Kingdom?",
        "Who was the first person to walk on the moon?",
        "Who wrote the American Declaration of Independence?",
        'Who wrote the novel "Pride and Prejudice"?',
    ]


@pytest.fixture
def sampling_params():
    """Standard sampling parameters for testing."""
    return SamplingParams(
        temperature=0.0,
        max_tokens=32,
        ignore_eos=True,
        logprobs=1,
    )


def test_disaggregated_serving(test_prompts, sampling_params):
    """
    Test disaggregated serving end-to-end.

    Equivalent to:
    PREFILL_SLICES=4 DECODE_SLICES=4 python examples/offline_inference.py \
        --model=meta-llama/Meta-Llama-3.1-8B-Instruct --task=generate \
        --max_model_len=2048 --tensor_parallel_size 4
    """
    # Set environment variables for disaggregated serving: 4 slices for
    # prefill and 4 for decode. Slice specs are usually TPU topologies such
    # as "2x2", but disagg_utils._parse_slices also accepts a plain count
    # like "4" (1D), which matches the example command above.

    # Mock the environment variables for this test.
    with patch.dict(
            os.environ, {
                "PREFILL_SLICES": "4",
                "DECODE_SLICES": "4",
                "SKIP_JAX_PRECOMPILE": "1",
                "VLLM_XLA_CHECK_RECOMPILATION": "0"
            }):
        # Patch the EngineCore classes to use the Disagg versions.
        with patch("vllm.v1.engine.core.EngineCore", DisaggEngineCore), \
                patch("vllm.v1.engine.core.EngineCoreProc", DisaggEngineCoreProc):

            model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
            os.system(f"rm -rf {vllm_envs.VLLM_XLA_CACHE_PATH}/*")
            engine_args = EngineArgs(
                model=model_name,
                max_model_len=2048,
                tensor_parallel_size=4,
                gpu_memory_utilization=0.90,
                enforce_eager=False,
            )

            llm = LLM(**asdict(engine_args))

            try:
                outputs = llm.generate(test_prompts, sampling_params)

                # Verify outputs
                assert len(outputs) == len(test_prompts)
                for output in outputs:
                    assert len(output.outputs) > 0
                    assert len(output.outputs[0].text.strip()) > 0
                    print(f"Prompt: {output.prompt!r}")
                    print(f"Generated: {output.outputs[0].text!r}")

            finally:
                del llm
                time.sleep(10)


def _run_inference(model_name: str,
                   test_prompts: list,
                   sampling_params: SamplingParams,
                   tensor_parallel_size: int = 1,
                   is_disagg: bool = False,
                   prefill_slices: str = "4",
                   decode_slices: str = "4") -> list:
    """Helper function to run inference with the specified configuration."""

    # Define the inner execution logic.
    def run_inner():
        engine_args = EngineArgs(
            model=model_name,
            max_model_len=2048,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=0.90,
            enforce_eager=False,
        )

        llm = LLM(**asdict(engine_args))
        try:
            return llm.generate(test_prompts, sampling_params)
        finally:
            del llm
            time.sleep(10)

    if is_disagg:
        # Mock environment variables and patch classes for disagg.
        with patch.dict(
                os.environ, {
                    "PREFILL_SLICES": prefill_slices,
                    "DECODE_SLICES": decode_slices,
                    "SKIP_JAX_PRECOMPILE": "1",
                    "VLLM_XLA_CHECK_RECOMPILATION": "0"
                }):
            with patch("vllm.v1.engine.core.EngineCore", DisaggEngineCore), \
                    patch("vllm.v1.engine.core.EngineCoreProc", DisaggEngineCoreProc):
                return run_inner()
    else:
        # Run standard inference. Keep the baseline as close to the default
        # configuration as possible, but match the disagg run's JAX settings.
        with patch.dict(os.environ, {
                "SKIP_JAX_PRECOMPILE": "1",
                "VLLM_XLA_CHECK_RECOMPILATION": "0"
        }):
            return run_inner()


def test_disaggregated_serving_correctness(test_prompts, sampling_params):
    """
    Test that disaggregated serving produces consistent results compared to a baseline.
    """
    model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    # Use a smaller subset of prompts for correctness testing.
    small_prompts = test_prompts[:20]
    sampling_params.max_tokens = 16

    # Run baseline (standard execution). tensor_parallel_size=4 matches the
    # resources used by the disaggregated run.
    print("Running Baseline Inference...")
    baseline_outputs = _run_inference(model_name=model_name,
                                      test_prompts=small_prompts,
                                      sampling_params=sampling_params,
                                      tensor_parallel_size=4,
                                      is_disagg=False)

    # Run disaggregated inference.
    os.system(f"rm -rf {vllm_envs.VLLM_XLA_CACHE_PATH}/*")
    print("Running Disaggregated Inference...")

    disagg_outputs = _run_inference(model_name=model_name,
                                    test_prompts=small_prompts,
                                    sampling_params=sampling_params,
                                    tensor_parallel_size=4,
                                    is_disagg=True,
                                    prefill_slices="4",
                                    decode_slices="4")

    # Compare outputs
    assert len(baseline_outputs) == len(disagg_outputs)

    text_matches = 0
    text_mismatches = 0
    token_mismatches = 0

    for i, (baseline,
            disagg) in enumerate(zip(baseline_outputs, disagg_outputs)):
        baseline_text = baseline.outputs[0].text.strip()
        disagg_text = disagg.outputs[0].text.strip()

        # Check text output
        if baseline_text == disagg_text:
            text_matches += 1
        else:
            text_mismatches += 1
            print(f"Text mismatch found in prompt {i}:")
            print(f"  Baseline: {baseline_text}")
            print(f"  Disagg: {disagg_text}")

        # Check log probabilities (tokens) if available
        baseline_logprobs = baseline.outputs[0].logprobs
        disagg_logprobs = disagg.outputs[0].logprobs

        if baseline_logprobs is not None and disagg_logprobs is not None:
            assert len(baseline_logprobs) == len(disagg_logprobs), \
                f"Logprobs length mismatch: {len(baseline_logprobs)} vs {len(disagg_logprobs)}"

            for token_idx, (base_lp, disagg_lp) in enumerate(
                    zip(baseline_logprobs, disagg_logprobs)):
                if base_lp and disagg_lp:
                    # Compare the top token IDs
                    base_top_token = list(base_lp.keys())[0]
                    disagg_top_token = list(disagg_lp.keys())[0]

                    if base_top_token != disagg_top_token:
                        token_mismatches += 1
                        print(
                            f"Token mismatch in prompt {i}, token {token_idx}:"
                        )
                        print(f"  Baseline: {base_top_token}")
                        print(f"  Disagg: {disagg_top_token}")

    print("✓ Correctness test results:")
    print(f"  Text: {text_matches} matches, {text_mismatches} mismatches")
    print(f"  Token mismatches in logprobs: {token_mismatches}")
    assert text_mismatches <= 5, f"Found {text_mismatches} text mismatches"
    assert token_mismatches <= 40, f"Found {token_mismatches} token mismatches"
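Note: the test above relies on PREFILL_SLICES/DECODE_SLICES accepting either a plain chip count ("4") or a topology string ("2x2"), as handled by tpu_inference/core/disagg_utils.py. A minimal sketch of that convention, using a hypothetical helper name (the real parser may differ):

# Hypothetical sketch only: the actual parsing lives in disagg_utils._parse_slices
# and may behave differently; this just illustrates the "4" vs "2x2" convention.
from math import prod

def parse_slice_spec(spec: str) -> int:
    """Return the chip count implied by a spec like "4" (1D) or "2x2" (topology)."""
    return prod(int(dim) for dim in spec.split("x"))

assert parse_slice_spec("4") == 4
assert parse_slice_spec("2x2") == 4
assert parse_slice_spec("2x2x2") == 8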
tests/e2e/test_model_loader.py
@@ -0,0 +1,268 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# tests/e2e/test_model_loader.py

import os
import re
import signal
import subprocess
import sys
import tempfile
import time

import pytest
import requests
import torch
from flax import nnx
from vllm.model_executor.models.registry import ModelRegistry

from tpu_inference.models.common.model_loader import (_MODEL_REGISTRY,
                                                      register_model)


@pytest.fixture
def cleanup_registries():
    """Cleans up the model registries before and after each test."""
    _MODEL_REGISTRY.clear()
    # vLLM's ModelRegistry uses a class-level dictionary to store model classes.
    # We need to clear it to ensure test isolation.
    if hasattr(ModelRegistry, "models"):
        ModelRegistry.models.clear()
    yield
    _MODEL_REGISTRY.clear()
    if hasattr(ModelRegistry, "models"):
        ModelRegistry.models.clear()


class DummyGoodModel(nnx.Module):
    """A valid model that conforms to the expected interface."""

    def __init__(self, vllm_config=None, rng=None, mesh=None):
        pass

    def __call__(self,
                 kv_caches=None,
                 input_ids=None,
                 attention_metadata=None):
        pass


def test_register_model_success(cleanup_registries):
    """Tests that a valid model is registered successfully."""
    arch = "DummyGoodModelForCausalLM"
    register_model(arch, DummyGoodModel)

    # Check tpu_inference registry
    assert arch in _MODEL_REGISTRY

    class MockModelConfig:

        def __init__(self, architectures):
            self.hf_config = self._MockHfConfig(architectures)
            self.model_impl = "flax_nnx"

        class _MockHfConfig:

            def __init__(self, architectures):
                self.architectures = architectures

    model_config = MockModelConfig(architectures=[arch])
    vllm_compatible_model, _ = ModelRegistry.resolve_model_cls(
        architectures=[arch], model_config=model_config)
    assert vllm_compatible_model is not None
    assert issubclass(vllm_compatible_model, torch.nn.Module)
    assert issubclass(vllm_compatible_model, DummyGoodModel)


try:
    # Attempt to import vLLM's interface validation function
    from vllm.model_executor.models.interfaces_base import is_vllm_model
    VLLM_INTERFACE_CHECK_AVAILABLE = True
except ImportError:
    VLLM_INTERFACE_CHECK_AVAILABLE = False


@pytest.mark.skipif(not VLLM_INTERFACE_CHECK_AVAILABLE,
                    reason="is_vllm_model could not be imported from vllm.")
def test_registered_model_passes_vllm_interface_check(cleanup_registries):
    """
    Ensures the wrapped model passes vLLM's own interface validation.

    This test is future-proof. If vLLM adds new requirements to its
    model interface, this test will fail, signaling that the wrapper
    in `register_model` needs to be updated.
    """
    arch = "DummyGoodModelForCausalLM"
    register_model(arch, DummyGoodModel)

    class MockModelConfig:

        def __init__(self, architectures):
            self.hf_config = self._MockHfConfig(architectures)
            self.model_impl = "flax_nnx"

        class _MockHfConfig:

            def __init__(self, architectures):
                self.architectures = architectures

    model_config = MockModelConfig(architectures=[arch])
    vllm_compatible_model, _ = ModelRegistry.resolve_model_cls(
        architectures=[arch], model_config=model_config)

    # This directly uses vLLM's checker, so it's always up-to-date.
    # We assume is_vllm_model returns True for a valid model, and either
    # returns False or raises an exception for an invalid one.
    assert is_vllm_model(vllm_compatible_model)


def _run_server_and_bench(model_name: str, model_impl_type: str,
                          port: int) -> float:
    env = os.environ.copy()
    env["MODEL_IMPL_TYPE"] = model_impl_type

    # Start server
    server_cmd = [
        sys.executable,
        "-m",
        "vllm.entrypoints.cli.main",
        "serve",
        model_name,
        "--port",
        str(port),
        "--max-model-len",
        "2048",
        "--tensor-parallel-size",
        "1",
        "--disable-log-requests",
        "--no-enable-prefix-caching",
        "--gpu-memory-utilization",
        "0.90",
    ]

    print(f"Starting server ({model_impl_type}) on port {port}...")
    # Use a new process group so we can kill the server and its children.
    # Use temporary files for stdout/stderr to avoid pipe buffer deadlocks.
    stdout_file = tempfile.TemporaryFile(mode='w+b')
    stderr_file = tempfile.TemporaryFile(mode='w+b')
    server_process = subprocess.Popen(server_cmd,
                                      env=env,
                                      stdout=stdout_file,
                                      stderr=stderr_file,
                                      preexec_fn=os.setsid)

    try:
        # Wait for server to be ready
        start_time = time.time()
        server_ready = False
        while time.time() - start_time < 600:  # 10 minutes timeout
            try:
                if requests.get(
                        f"http://localhost:{port}/health").status_code == 200:
                    server_ready = True
                    break
            except requests.exceptions.RequestException:
                pass

            if server_process.poll() is not None:
                stdout_file.seek(0)
                stderr_file.seek(0)
                stdout = stdout_file.read().decode("utf-8", errors="replace")
                stderr = stderr_file.read().decode("utf-8", errors="replace")
                raise RuntimeError(
                    f"Server process exited unexpectedly.\nStdout: {stdout}\nStderr: {stderr}"
                )

            time.sleep(5)

        if not server_ready:
            stdout_file.seek(0)
            stderr_file.seek(0)
            stdout = stdout_file.read().decode("utf-8", errors="replace")
            stderr = stderr_file.read().decode("utf-8", errors="replace")
            raise RuntimeError(
                f"Server failed to start within timeout.\nStdout: {stdout}\nStderr: {stderr}"
            )

        print("Server is ready. Running benchmark...")

        # Run benchmark
        bench_cmd = [
            "vllm", "bench", "serve", "--model", model_name, "--port",
            str(port), "--dataset-name", "random", "--random-input-len", "50",
            "--random-output-len", "128", "--num-prompts", "20"
        ]

        result = subprocess.run(bench_cmd,
                                env=env,
                                capture_output=True,
                                text=True)

        if result.returncode != 0:
            raise RuntimeError(
                f"Benchmark failed.\nStdout: {result.stdout}\nStderr: {result.stderr}"
            )

        # Parse throughput
        # Output example: "Request throughput (req/s): 12.34"
        match = re.search(r"Request throughput \(req/s\):\s+([\d\.]+)",
                          result.stdout)
        if not match:
            raise ValueError(
                f"Could not parse throughput from output:\n{result.stdout}")

        throughput = float(match.group(1))
        return throughput

    finally:
        print("Stopping server...")
        try:
            os.killpg(os.getpgid(server_process.pid), signal.SIGTERM)
        except ProcessLookupError:
            pass
        server_process.wait()
        stdout_file.close()
        stderr_file.close()
        # Wait for TPU cleanup
        time.sleep(5)


def test_flax_nnx_vs_vllm_performance():
    """
    Compares the performance of the flax_nnx and vllm model implementations.

    This test ensures that the JAX-native (`flax_nnx`) implementation's
    performance is not significantly different from the vLLM-native PyTorch
    (`vllm`) implementation. It measures the request throughput for both
    backends and asserts that the percentage difference is within a
    reasonable threshold.
    """
    model_name = "Qwen/Qwen3-4B"
    # This should be 2-3%, but 6% reduces flakiness.
    percentage_difference_threshold = 0.06

    throughput_vllm = _run_server_and_bench(model_name, "vllm", 8001)
    throughput_flax = _run_server_and_bench(model_name, "flax_nnx", 8002)

    print(f"vLLM (PyTorch) throughput: {throughput_vllm:.2f} req/s.")
    print(f"flax_nnx (JAX) throughput: {throughput_flax:.2f} req/s.")

    percentage_diff = abs(throughput_flax - throughput_vllm) / throughput_vllm
    print(f"Percentage difference in throughput: {percentage_diff:.2%}.")

    assert percentage_diff < percentage_difference_threshold, (
        f"The performance difference between flax_nnx and vllm is too high. "
        f"Difference: {percentage_diff:.2%}, Threshold: {percentage_difference_threshold:.2%}"
    )
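The benchmark helper above extracts throughput with a regex over the `vllm bench serve` output; as a quick standalone check of that pattern against the sample line quoted in its comment (illustration only):

# Illustration of the throughput-parsing regex used in _run_server_and_bench,
# applied to the sample output line from the comment above.
import re

sample = "Request throughput (req/s): 12.34"
match = re.search(r"Request throughput \(req/s\):\s+([\d\.]+)", sample)
assert match is not None
assert float(match.group(1)) == 12.34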
tests/e2e/test_multi_modal_inference.py
@@ -0,0 +1,111 @@
# SPDX-License-Identifier: Apache-2.0
#
# A simplified example to run multi-modal inference and verify the output.
# This script is a self-contained test that runs a single prompt and
# compares the output to a known-good output.

import difflib
import os
from dataclasses import asdict

import pytest
from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.multimodal.image import convert_image_mode

# Expected partial text output from the model. This is based on a previous
# run and is used for verification. The test passes if the generated output
# is sufficiently similar to this text.
EXPECTED_TEXT = (
    "The image depicts a tall, cylindrical tower with a lattice-like structure, surrounded by cherry blossom trees in full bloom. The cherry blossoms are in various stages of opening, with pink petals covering the branches. The sky is clear and blue, providing a vibrant backdrop to the scene. The tower appears to be a significant landmark"
)


# NOTE: Could be extended to more mm models/configs as needed
@pytest.mark.parametrize("enable_dynamic_image_sizes", [False, True])
def test_multi_modal_inference(monkeypatch, enable_dynamic_image_sizes):
    """
    Runs multi-modal inference and verifies the output.
    """
    os.environ['SKIP_JAX_PRECOMPILE'] = '1'  # Skip warmup to save time.
    os.environ[
        'VLLM_XLA_CHECK_RECOMPILATION'] = '0'  # Allow compilation during execution.

    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")

    # --- Configuration ---
    model = "Qwen/Qwen2.5-VL-3B-Instruct"
    tensor_parallel_size = 1
    temperature = 0.0
    max_tokens = 64
    max_model_len = 4096
    gpu_memory_utilization = 0.5
    modality = "image"

    print("Preparing for multi-modal inference...")

    # --- Prepare Inputs ---
    image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
    question = "What is the content of this image?"

    # Using the Qwen2.5-VL prompt template.
    # NOTE: other models may differ.
    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
              f"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
              f"{question}<|im_end|>\n"
              "<|im_start|>assistant\n")

    # --- Setup vLLM Engine ---
    engine_args = EngineArgs(
        model=model,
        max_model_len=max_model_len,
        tensor_parallel_size=tensor_parallel_size,
        gpu_memory_utilization=gpu_memory_utilization,
        max_num_seqs=1,
        mm_processor_kwargs={
            "min_pixels": 28 * 28,
            "max_pixels": 1280 * 28 * 28,
            "fps": 1,
        },
        limit_mm_per_prompt={modality: 1},
    )
    engine_args = asdict(engine_args)
    if engine_args.get("additional_config") is None:
        engine_args["additional_config"] = {}

    engine_args["additional_config"][
        "enable_dynamic_image_sizes"] = enable_dynamic_image_sizes
    llm = LLM(**engine_args)

    sampling_params = SamplingParams(
        temperature=temperature,
        max_tokens=max_tokens,
    )

    inputs = {
        "prompt": prompt,
        "multi_modal_data": {
            "image": image
        },
    }

    # --- Run Inference ---
    print("Running inference...")
    outputs = llm.generate(inputs, sampling_params)

    # --- Verification ---
    generated_text = outputs[0].outputs[0].text.strip()

    print("-" * 50)
    print("Generated Text:")
    print(generated_text)
    print("-" * 50)

    # Check output
    similarity_score = difflib.SequenceMatcher(None, generated_text,
                                               EXPECTED_TEXT).ratio()
    print(f"Similarity Score: {similarity_score:.4f}")
    assert similarity_score >= 0.85, (
        f"Text similarity too low ({similarity_score:.2f}).\n"
        f"Expected: {EXPECTED_TEXT}\n"
        f"Actual: {generated_text}")