tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tests/e2e/test_hybrid_kvcache.py
@@ -0,0 +1,219 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
+import time
+from dataclasses import asdict
+
+import pytest
+from vllm import LLM, EngineArgs, SamplingParams
+
+
+@pytest.fixture
+def model_name():
+    """Choose gemma-27b as the test model as it has both full attention and
+    sliding window attention."""
+    return "google/gemma-3-27b-it"
+
+
+@pytest.fixture
+def test_prompts():
+    """Simple test prompts for hybrid kv cache testing."""
+    return [
+        "Hello, my name is",
+        "The capital of France is",
+        "The colors of the rainbow are",
+        "The future of AI is",
+        "The president of the United States is",
+        "How many players are on a standard soccer team?",
+        "In Greek mythology, who is the god of the sea?",
+        "What is the capital of Australia?",
+        "What is the largest planet in our solar system?",
+        "Who developed the theory of general relativity?",
+    ]
+
+
+@pytest.fixture
+def sampling_params():
+    """Standard sampling parameters for testing."""
+    return SamplingParams(
+        temperature=0.0,
+        max_tokens=32,
+        ignore_eos=True,
+        logprobs=1,
+    )
+
+
+def _run_inference_with_config(
+        model_name: str,
+        test_prompts: list,
+        sampling_params: SamplingParams,
+        tensor_parallel_size: int = 4,
+        kv_cache_dtype: str = "auto",
+        enable_prefix_caching: bool = False,
+        disable_hybrid_kv_cache_manager: bool = False) -> list:
+    """Helper function to run inference with specified configuration."""
+
+    # Create LLM args using parser-based approach similar to offline_inference.py
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=64,
+        tensor_parallel_size=tensor_parallel_size,
+        gpu_memory_utilization=0.95,
+        max_num_batched_tokens=256,
+        max_num_seqs=16,
+        enable_prefix_caching=enable_prefix_caching,
+        kv_cache_dtype=kv_cache_dtype,
+        disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager,
+    )
+
+    engine_args_dict = asdict(engine_args)
+    llm = LLM(**engine_args_dict)
+
+    try:
+        outputs = llm.generate(test_prompts, sampling_params)
+        return outputs
+    finally:
+        del llm
+        # Wait for TPUs to be released
+        time.sleep(10)
+
+
+def test_hybrid_kv_cache(
+    model_name: str,
+    test_prompts: list,
+    sampling_params: SamplingParams,
+):
+    """
+    Test hybrid kv cache works on gemma vLLM models.
+    """
+
+    os.environ['MODEL_IMPL_TYPE'] = 'vllm'
+    # Test with hybrid kv cache alloctaion enabled.
+    outputs = _run_inference_with_config(
+        model_name=model_name,
+        test_prompts=test_prompts,
+        sampling_params=sampling_params,
+        disable_hybrid_kv_cache_manager=False,
+    )
+
+    # Verify we got outputs for all prompts
+    assert len(outputs) == len(test_prompts)
+
+    # Verify each output has generated text
+    for output in outputs:
+        assert len(output.outputs) > 0
+        assert len(output.outputs[0].text.strip()) > 0
+
+    print(f"✓ Hybrid KV cache test passed with {len(outputs)} outputs")
+
+
+def test_hybrid_kv_cache_correctness(
+    model_name: str,
+    test_prompts: list,
+    sampling_params: SamplingParams,
+):
+    """
+    Test that hybrid kv cache allocation produces consistent results compared
+    to standard kv cache allocation.
+    """
+    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
+    os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '0'
+
+    small_prompts = test_prompts
+
+    # Run baseline (no hybrid kv cache)
+    baseline_outputs = _run_inference_with_config(
+        model_name=model_name,
+        test_prompts=small_prompts,
+        sampling_params=sampling_params,
+        disable_hybrid_kv_cache_manager=True,
+    )
+
+    # Run with hybrid kv cache enabled.
+    hybrid_kvcache_outputs = _run_inference_with_config(
+        model_name=model_name,
+        test_prompts=small_prompts,
+        sampling_params=sampling_params,
+        disable_hybrid_kv_cache_manager=False,
+    )
+
+    # Compare outputs - in theory they should be identical for greedy sampling
+    # in reality there may be some differences, but overall the outputs should
+    # be very similar.
+
+    # an example:
+    # prompt: What is the capital of Australia?
+    # both answers should be acceptable.
+    # The capital of Australia is Canberra. It is located in the Australian Capital Territory (ACT) and is home to many
+    # Canberra is the capital of Australia. It is located in the Australian Capital Territory (ACT) and is home to
+    assert len(baseline_outputs) == len(hybrid_kvcache_outputs)
+
+    text_matches = 0
+    text_mismatches = 0
+    logprob_mismatches = 0
+    max_logprob_diff = 0.0
+
+    for i, (baseline, hybrid_kvcache_result) in enumerate(
+            zip(baseline_outputs, hybrid_kvcache_outputs)):
+        baseline_text = baseline.outputs[0].text.strip()
+        hybrid_kvcache_text = hybrid_kvcache_result.outputs[0].text.strip()
+
+        # Check text output
+        if baseline_text == hybrid_kvcache_text:
+            text_matches += 1
+        else:
+            text_mismatches += 1
+            print(f"Text mismatch found in prompt {i}:")
+            print(f" Baseline: {baseline_text}")
+            print(f" Hybrid KV Cache: {hybrid_kvcache_text}")
+
+        # Check log probabilities
+        baseline_logprobs = baseline.outputs[0].logprobs
+        hybrid_kvcache_logprobs = hybrid_kvcache_result.outputs[0].logprobs
+        if baseline_logprobs is not None and hybrid_kvcache_logprobs is not None:
+            # Compare log probabilities for each token
+            assert len(baseline_logprobs) == len(hybrid_kvcache_logprobs), \
+                f"Logprobs length mismatch: {len(baseline_logprobs)} vs {len(hybrid_kvcache_logprobs)}"
+            for token_idx, (base_lp, hybrid_kvcache_lp) in enumerate(
+                    zip(baseline_logprobs, hybrid_kvcache_logprobs)):
+                # Get the top logprob value for the selected token
+                if base_lp and hybrid_kvcache_lp:
+                    # Get the top token's logprob from each
+                    base_top_token = list(base_lp.keys())[0]
+                    hybrid_kvcache_top_token = list(
+                        hybrid_kvcache_lp.keys())[0]
+
+                    base_logprob_val = base_lp[base_top_token].logprob
+                    hybrid_kvcache_logprob_val = hybrid_kvcache_lp[
+                        hybrid_kvcache_top_token].logprob
+
+                    # Calculate absolute difference
+                    diff = abs(base_logprob_val - hybrid_kvcache_logprob_val)
+                    max_logprob_diff = max(max_logprob_diff, diff)
+
+                    # Allow small numerical differences (e.g., 1e-3)
+                    if diff > 1e-3:
+                        logprob_mismatches += 1
+                        print(
+                            f"Logprob mismatch in prompt {i}, token {token_idx}:"
+                        )
+                        print(
+                            f" Baseline token: {base_top_token}, logprob: {base_logprob_val:.6f}"
+                        )
+                        print(
+                            f" Hybrid KV Cache token: {hybrid_kvcache_top_token}, logprob: {hybrid_kvcache_logprob_val:.6f}"
+                        )
+                        print(f" Difference: {diff:.6f}")
+
+    print("✓ Correctness test results:")
+    print(f" Text: {text_matches} matches, {text_mismatches} mismatches")
+    print(f" Max logprob difference: {max_logprob_diff:.6e}")
+    print(f" Significant logprob mismatches (>1e-3): {logprob_mismatches}")
+
+    # Allow for some variance due to potential numerical differences
+    # but most outputs should match with greedy sampling
+    text_match_rate = text_matches / len(baseline_outputs)
+    assert text_match_rate >= 0.9, f"Text match rate {text_match_rate:.2%} is too low"
+
+    # Log probabilities should be very close (allow small numerical errors)
+    assert max_logprob_diff < 2, f"Max logprob difference {max_logprob_diff} is too large"
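For context, the following is a minimal sketch (not part of the diff) of how the `_run_inference_with_config` helper added above could be exercised directly. It assumes a checkout of the repository so that `tests.e2e.test_hybrid_kvcache` is importable (the new `tests/__init__.py` and `tests/e2e/__init__.py` packages in this release make that possible), plus a TPU host with at least 4 chips to satisfy the helper's default `tensor_parallel_size=4`.

# Hypothetical standalone use of the helper shown in the diff above; model name,
# sampling settings, and output access mirror the test file itself.
from vllm import SamplingParams

from tests.e2e.test_hybrid_kvcache import _run_inference_with_config

params = SamplingParams(temperature=0.0, max_tokens=32, ignore_eos=True, logprobs=1)
outputs = _run_inference_with_config(
    model_name="google/gemma-3-27b-it",
    test_prompts=["The capital of France is"],
    sampling_params=params,
    disable_hybrid_kv_cache_manager=False,  # keep the hybrid KV cache manager enabled
)
print(outputs[0].outputs[0].text)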
tests/e2e/test_local_disagg.py
@@ -0,0 +1,257 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import os
+import time
+from dataclasses import asdict
+from unittest.mock import patch
+
+import pytest
+import vllm.envs as vllm_envs
+from vllm import LLM, EngineArgs, SamplingParams
+
+from tpu_inference.core.core_tpu import DisaggEngineCore, DisaggEngineCoreProc
+
+
+@pytest.fixture
+def test_prompts():
+    """Simple test prompts for disaggregated serving testing."""
+    return [
+        "Hello, my name is",
+        "The capital of France is",
+        "The colors of the rainbow are",
+        "The future of AI is",
+        "The president of the United States is",
+        "How many players are on a standard soccer team on the field at one time?",
+        "In Greek mythology, who is the god of the sea?",
+        "In what year did the Titanic sink?",
+        "In which museum is the Mona Lisa displayed?",
+        "Mount Everest is located in which mountain range?",
+        "What ancient empire was ruled by Julius Caesar?",
+        "What are the four fundamental forces of nature?",
+        'What does "CPU" stand for?',
+        'What does "HTML" stand for?',
+        "What is the capital of Australia?",
+        "What is the chemical symbol for gold?",
+        "What is the currency of Switzerland?",
+        "What is the distance from the Earth to the Sun called?",
+        "What is the freezing point of water in Celsius?",
+        "What is the hardest known natural substance on Earth?",
+        "What is the largest planet in our solar system?",
+        "What is the longest river in the world?",
+        "What is the main function of the kidneys in the human body?",
+        "What is the main ingredient in guacamole?",
+        "What is the most spoken language in the world by number of native speakers?",
+        "What is the process by which plants use sunlight to create food?",
+        "Which country is known as the Land of the Rising Sun?",
+        "Who developed the theory of general relativity?",
+        'Who directed the original "Star Wars" trilogy?',
+        "Who is credited with inventing the telephone?",
+        "Who painted the ceiling of the Sistine Chapel?",
+        "Who was the first female Prime Minister of the United Kingdom?",
+        "Who was the first person to walk on the moon?",
+        "Who wrote the American Declaration of Independence?",
+        'Who wrote the novel "Pride and Prejudice"?',
+    ]
+
+
+@pytest.fixture
+def sampling_params():
+    """Standard sampling parameters for testing."""
+    return SamplingParams(
+        temperature=0.0,
+        max_tokens=32,
+        ignore_eos=True,
+        logprobs=1,
+    )
+
+
+def test_disaggregated_serving(test_prompts, sampling_params):
+    """
+    Test disaggregated serving end-to-end.
+
+    Equivalent to:
+    PREFILL_SLICES=4 DECODE_SLICES=4 python examples/offline_inference.py \
+        --model=meta-llama/Meta-Llama-3.1-8B-Instruct --task=generate \
+        --max_model_len=2048 --tensor_parallel_size 4
+    """
+    # Set environment variables for disaggregated serving
+    # Using 4 slices for prefill and 4 for decode as requested
+    # Note: The user example used PREFILL_SLICES=4 DECODE_SLICES=4
+    # But usually slices are specified as "2x2" or similar if they are TPU topology.
+    # However, disagg_utils.py _parse_slices handles "4" as well (1D).
+    # We will stick to the user's example values.
+
+    # We need to mock the environment variables for this test
+    with patch.dict(
+            os.environ, {
+                "PREFILL_SLICES": "4",
+                "DECODE_SLICES": "4",
+                "SKIP_JAX_PRECOMPILE": "1",
+                "VLLM_XLA_CHECK_RECOMPILATION": "0"
+            }):
+        # Patch the EngineCore classes to use Disagg versions
+        with patch("vllm.v1.engine.core.EngineCore", DisaggEngineCore), \
+                patch("vllm.v1.engine.core.EngineCoreProc", DisaggEngineCoreProc):
+
+            model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+            os.system(f"rm -rf {vllm_envs.VLLM_XLA_CACHE_PATH}/*")
+            engine_args = EngineArgs(
+                model=model_name,
+                max_model_len=2048,
+                tensor_parallel_size=4,
+                gpu_memory_utilization=0.90,
+                enforce_eager=False,
+            )
+
+            llm = LLM(**asdict(engine_args))
+
+            try:
+                outputs = llm.generate(test_prompts, sampling_params)
+
+                # Verify outputs
+                assert len(outputs) == len(test_prompts)
+                for output in outputs:
+                    assert len(output.outputs) > 0
+                    assert len(output.outputs[0].text.strip()) > 0
+                    print(f"Prompt: {output.prompt!r}")
+                    print(f"Generated: {output.outputs[0].text!r}")
+
+            finally:
+                del llm
+                time.sleep(10)
+                pass
+
+
+def _run_inference(model_name: str,
+                   test_prompts: list,
+                   sampling_params: SamplingParams,
+                   tensor_parallel_size: int = 1,
+                   is_disagg: bool = False,
+                   prefill_slices: str = "4",
+                   decode_slices: str = "4") -> list:
+    """Helper function to run inference with specified configuration."""
+
+    # Define the inner execution logic
+    def run_inner():
+        engine_args = EngineArgs(
+            model=model_name,
+            max_model_len=2048,
+            tensor_parallel_size=tensor_parallel_size,
+            gpu_memory_utilization=0.90,
+            enforce_eager=False,
+        )
+
+        llm = LLM(**asdict(engine_args))
+        try:
+            return llm.generate(test_prompts, sampling_params)
+        finally:
+            del llm
+            time.sleep(10)
+            pass
+
+    if is_disagg:
+        # Mock environment variables and patch classes for disagg
+        with patch.dict(
+                os.environ, {
+                    "PREFILL_SLICES": prefill_slices,
+                    "DECODE_SLICES": decode_slices,
+                    "SKIP_JAX_PRECOMPILE": "1",
+                    "VLLM_XLA_CHECK_RECOMPILATION": "0"
+                }):
+            with patch("vllm.v1.engine.core.EngineCore", DisaggEngineCore), \
+                    patch("vllm.v1.engine.core.EngineCoreProc", DisaggEngineCoreProc):
+                return run_inner()
+    else:
+        # Run standard inference
+        # We still set some env vars to ensure consistent behavior if needed
+        # but for baseline we want it as standard as possible.
+        # However, to match the disagg run's potential jax settings:
+        with patch.dict(os.environ, {
+                "SKIP_JAX_PRECOMPILE": "1",
+                "VLLM_XLA_CHECK_RECOMPILATION": "0"
+        }):
+            return run_inner()
+
+
+def test_disaggregated_serving_correctness(test_prompts, sampling_params):
+    """
+    Test that disaggregated serving produces consistent results compared to a baseline.
+    """
+    model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+    # Use a smaller subset of prompts for correctness testing
+    small_prompts = test_prompts[:20]
+    sampling_params.max_tokens = 16
+
+    # Run baseline (standard execution)
+    # We use tensor_parallel_size=4 to match the disagg resources if we assume
+    # the user has enough chips, or if we are just mocking.
+    # Since the original test used tp=4, we stick to it.
+    print("Running Baseline Inference...")
+    baseline_outputs = _run_inference(model_name=model_name,
+                                      test_prompts=small_prompts,
+                                      sampling_params=sampling_params,
+                                      tensor_parallel_size=4,
+                                      is_disagg=False)
+
+    # Run disaggregated inference
+    os.system(f"rm -rf {vllm_envs.VLLM_XLA_CACHE_PATH}/*")
+    print("Running Disaggregated Inference...")
+
+    disagg_outputs = _run_inference(model_name=model_name,
+                                    test_prompts=small_prompts,
+                                    sampling_params=sampling_params,
+                                    tensor_parallel_size=4,
+                                    is_disagg=True,
+                                    prefill_slices="4",
+                                    decode_slices="4")
+
+    # Compare outputs
+    assert len(baseline_outputs) == len(disagg_outputs)
+
+    text_matches = 0
+    text_mismatches = 0
+    token_mismatches = 0
+
+    for i, (baseline,
+            disagg) in enumerate(zip(baseline_outputs, disagg_outputs)):
+        baseline_text = baseline.outputs[0].text.strip()
+        disagg_text = disagg.outputs[0].text.strip()
+
+        # Check text output
+        if baseline_text == disagg_text:
+            text_matches += 1
+        else:
+            text_mismatches += 1
+            print(f"Text mismatch found in prompt {i}:")
+            print(f" Baseline: {baseline_text}")
+            print(f" Disagg: {disagg_text}")
+
+        # Check log probabilities (tokens) if available
+        baseline_logprobs = baseline.outputs[0].logprobs
+        disagg_logprobs = disagg.outputs[0].logprobs
+
+        if baseline_logprobs is not None and disagg_logprobs is not None:
+            assert len(baseline_logprobs) == len(disagg_logprobs), \
+                f"Logprobs length mismatch: {len(baseline_logprobs)} vs {len(disagg_logprobs)}"
+
+            for token_idx, (base_lp, disagg_lp) in enumerate(
+                    zip(baseline_logprobs, disagg_logprobs)):
+                if base_lp and disagg_lp:
+                    # Compare the top token IDs
+                    base_top_token = list(base_lp.keys())[0]
+                    disagg_top_token = list(disagg_lp.keys())[0]
+
+                    if base_top_token != disagg_top_token:
+                        token_mismatches += 1
+                        print(
+                            f"Token mismatch in prompt {i}, token {token_idx}:"
+                        )
+                        print(f" Baseline: {base_top_token}")
+                        print(f" Disagg: {disagg_top_token}")
+
+    print("✓ Correctness test results:")
+    print(f" Text: {text_matches} matches, {text_mismatches} mismatches")
+    print(f" Token mismatches in logprobs: {token_mismatches}")
+    assert text_mismatches <= 5, f"Found {text_mismatches} text mismatches"
+    assert token_mismatches <= 40, f"Found {token_mismatches} token mismatches"