tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import random
|
|
18
|
+
import string
|
|
19
|
+
import time
|
|
20
|
+
|
|
21
|
+
import pytest
|
|
22
|
+
from vllm import LLM, SamplingParams
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@pytest.fixture
def sampling_config():
    """Build the sampling parameters shared by the async scheduler tests.

    Returns:
        A ``SamplingParams`` with temperature 0, at most 120 new tokens,
        ``ignore_eos`` enabled, every penalty left at its neutral value,
        ``min_p`` of 0 and no logprobs requested.
    """
    settings = dict(
        temperature=0,
        max_tokens=120,
        ignore_eos=True,
        repetition_penalty=1,
        frequency_penalty=0,
        presence_penalty=0,
        min_p=0,
        logprobs=None,
    )
    return SamplingParams(**settings)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@pytest.fixture
def model_name():
    """Return the model identifier string used by the tests in this module."""
    checkpoint = "Qwen/Qwen2.5-1.5B-Instruct"
    return checkpoint
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def get_test_prompts(num_prompts: int = 500,
                     input_len_words: int = 120) -> list[str]:
    """
    Generates a list of prompts that ask the model to repeat one letter.

    Each prompt is ``"Keep repeating: "`` followed by a single randomly chosen
    lowercase ASCII letter repeated ``input_len_words`` times, separated by
    single spaces. For example, with the letter ``s`` the prompt is
    ``Keep repeating: s s s ...``.

    Args:
        num_prompts: The number of prompts to generate.
        input_len_words: How many times the letter is repeated in each prompt.

    Returns:
        A list of ``num_prompts`` prompt strings.
    """
    prefix = "Keep repeating: "

    # One random draw per prompt, taken in prompt order.
    return [
        prefix + " ".join([random.choice(string.ascii_lowercase)] *
                          input_len_words) for _ in range(num_prompts)
    ]
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _test_performance_helper(monkeypatch: pytest.MonkeyPatch,
                             sampling_config: SamplingParams, model_name: str,
                             min_speedup: float):
    '''
    Helper function to test async scheduler decoding performance.

    Runs the same prompts through a reference LLM (async scheduling off) and
    then an async LLM (async scheduling on), timing only the ``generate``
    call of each, and asserts that ``ref_time / async_time`` is at least
    ``min_speedup``.

    Args:
        monkeypatch: pytest monkeypatch fixture; its context wraps the run.
        sampling_config: Sampling parameters passed to both ``generate`` calls.
        model_name: Model identifier passed to ``LLM``.
        min_speedup: Minimum required ratio of reference time to async time.

    Raises:
        AssertionError: If the measured speedup is below ``min_speedup``.
    '''

    def _timed_generate(prompts, async_scheduling: bool) -> float:
        # Build an engine, time one generate() call, then release the engine.
        # Engine construction stays outside the timed region.
        llm = LLM(model=model_name,
                  max_model_len=800,
                  max_num_seqs=24,
                  max_num_batched_tokens=512,
                  enable_prefix_caching=False,
                  async_scheduling=async_scheduling)

        # perf_counter() is monotonic, so the interval cannot be skewed by
        # system clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        _ = llm.generate(prompts, sampling_config)
        elapsed = time.perf_counter() - start_time

        del llm
        # Waiting for TPUs to be released
        time.sleep(10)
        return elapsed

    with monkeypatch.context():
        # 500 prompts, each repeating one letter 120 times.
        test_prompts = get_test_prompts()

        # Reference run first, then the async run, with identical settings.
        ref_time = _timed_generate(test_prompts, async_scheduling=False)
        async_time = _timed_generate(test_prompts, async_scheduling=True)

        speedup = ref_time / async_time
        print(f"Reference LLM time: {ref_time:.2f}s")
        print(f"Async LLM time: {async_time:.2f}s")
        print(f"Speedup: {speedup:.2f}x")

        assert speedup >= min_speedup, f"Expected at least {min_speedup}x speedup for async scheduler, got {speedup:.2f}x"
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def test_performance(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_name: str,
):
    '''
    Check that the async scheduler gives a clear decoding speedup.

    Delegates to the performance helper, which times a reference LLM against
    an async LLM, and requires the async run to be at least 1.3x faster.
    '''
    _test_performance_helper(monkeypatch,
                             sampling_config,
                             model_name,
                             min_speedup=1.3)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def _test_correctness_helper(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_name: str,
):
    '''
    Helper function to test async scheduler correctness.
    The outputs of the original LLM and the async LLM should be the same
    when using async scheduler decoding.

    Known Edge Case (KV Cache Swapping):
    Under this case, though the temperature is set to 0,
    the output is still slightly different every time.
    This is expected behaviour, as the normal scheduler behaves the same
    way, and hence it is difficult to design a test for such a scenario.

    Args:
        monkeypatch: pytest monkeypatch fixture; its context wraps the run.
        sampling_config: Sampling parameters passed to both ``generate`` calls.
        model_name: Model identifier passed to ``LLM``.

    Raises:
        AssertionError: If the two runs return a different number of outputs
            or any pair of generated texts differs.
    '''
    with monkeypatch.context():
        test_prompts = get_test_prompts()

        ref_llm = LLM(model=model_name,
                      max_model_len=1024,
                      max_num_seqs=100,
                      async_scheduling=False)
        ref_outputs = ref_llm.generate(test_prompts, sampling_config)

        del ref_llm

        # Waiting for TPUs to be released.
        time.sleep(10)

        async_llm = LLM(model=model_name,
                        max_model_len=1024,
                        max_num_seqs=100,
                        async_scheduling=True)
        async_outputs = async_llm.generate(test_prompts, sampling_config)

        # Drop the engine itself (not the outputs) and wait before asserting,
        # so the cleanup also happens when the comparison below fails.
        del async_llm

        # Waiting for TPUs to be released.
        time.sleep(10)

        # zip() would silently ignore extra outputs on either side.
        assert len(ref_outputs) == len(async_outputs), (
            f"Expected {len(ref_outputs)} async outputs, "
            f"got {len(async_outputs)}")

        misses = 0
        for ref_output, async_output in zip(ref_outputs, async_outputs):
            ref_text = ref_output.outputs[0].text
            async_text = async_output.outputs[0].text
            print(f"ref_output: {ref_text}")
            print(f"async_output: {async_text}")
            if ref_text != async_text:
                misses += 1

        assert misses == 0, (
            f"{misses} of {len(ref_outputs)} outputs differ between the "
            "reference and async runs")
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def test_async_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_name: str,
):
    '''
    Check that the async scheduler does not change generated text.

    Delegates to the correctness helper, which compares the outputs of an
    original LLM with those of an async LLM.
    '''
    _test_correctness_helper(
        monkeypatch=monkeypatch,
        sampling_config=sampling_config,
        model_name=model_name,
    )
|
|
@@ -0,0 +1,393 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
3
|
+
|
|
4
|
+
import os
|
|
5
|
+
import time
|
|
6
|
+
from dataclasses import asdict
|
|
7
|
+
|
|
8
|
+
import pytest
|
|
9
|
+
from vllm import LLM, EngineArgs, SamplingParams
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@pytest.fixture(autouse=True)
def setup_new_model_design(monkeypatch: pytest.MonkeyPatch):
    """Automatically set NEW_MODEL_DESIGN=1 for all tests.

    The variable is set through ``monkeypatch`` so that pytest restores the
    previous environment after each test instead of leaving the value in
    ``os.environ`` for the rest of the session.
    """
    monkeypatch.setenv('NEW_MODEL_DESIGN', '1')
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@pytest.fixture
def test_prompts():
    """Return ten short prompts used by the data parallelism tests.

    The first five are open-ended sentence stems and the last five are
    direct questions; the order is fixed.
    """
    sentence_stems = [
        "Hello, my name is",
        "The capital of France is",
        "The colors of the rainbow are",
        "The future of AI is",
        "The president of the United States is",
    ]
    questions = [
        "How many players are on a standard soccer team?",
        "In Greek mythology, who is the god of the sea?",
        "What is the capital of Australia?",
        "What is the largest planet in our solar system?",
        "Who developed the theory of general relativity?",
    ]
    return sentence_stems + questions
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@pytest.fixture
def sampling_params():
    """Return the sampling parameters shared by the data parallelism tests.

    Temperature is 0.0, at most 32 tokens are generated, ``ignore_eos`` is
    enabled and one logprob per token is requested.
    """
    settings = {
        "temperature": 0.0,
        "max_tokens": 32,
        "ignore_eos": True,
        "logprobs": 1,
    }
    return SamplingParams(**settings)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _run_inference_with_config(model_name: str,
                               test_prompts: list,
                               sampling_params: SamplingParams,
                               tensor_parallel_size: int = 1,
                               data_parallel_size: int = 1,
                               additional_config: "dict | None" = None,
                               kv_cache_dtype: str = "auto",
                               enable_prefix_caching: bool = False,
                               async_scheduling: bool = False,
                               measure_time: bool = False,
                               max_model_len: int = 32,
                               max_num_batched_tokens: int = 128,
                               max_num_seqs: int = 16):
    """Helper function to run inference with specified configuration.

    Builds an ``EngineArgs`` from the arguments, creates an ``LLM`` from it,
    runs ``generate`` on the prompts, then deletes the engine and sleeps for
    5 seconds so the TPUs can be released.

    Args:
        model_name: Model identifier passed to ``EngineArgs``.
        test_prompts: Prompts passed to ``LLM.generate``.
        sampling_params: Sampling parameters passed to ``LLM.generate``.
        tensor_parallel_size: Forwarded to ``EngineArgs``.
        data_parallel_size: Forwarded to ``EngineArgs``.
        additional_config: Forwarded to ``EngineArgs``; ``None`` (the default)
            means an empty dict.
        kv_cache_dtype: Forwarded to ``EngineArgs``.
        enable_prefix_caching: Forwarded to ``EngineArgs``.
        async_scheduling: Forwarded to ``EngineArgs``.
        measure_time: Whether to also return the ``generate`` duration.
        max_model_len: Forwarded to ``EngineArgs``.
        max_num_batched_tokens: Forwarded to ``EngineArgs``.
        max_num_seqs: Forwarded to ``EngineArgs``.

    Returns:
        If measure_time=True: (outputs, elapsed_time) tuple
        If measure_time=False: outputs list
    """
    # A fresh dict per call: a `{}` default would be one object shared by
    # every call that omits the argument.
    if additional_config is None:
        additional_config = {}

    # Create LLM args using parser-based approach similar to offline_inference.py
    engine_args = EngineArgs(
        model=model_name,
        max_model_len=max_model_len,
        tensor_parallel_size=tensor_parallel_size,
        data_parallel_size=data_parallel_size,
        gpu_memory_utilization=0.98,
        max_num_batched_tokens=max_num_batched_tokens,
        max_num_seqs=max_num_seqs,
        enable_prefix_caching=enable_prefix_caching,
        additional_config=additional_config,
        kv_cache_dtype=kv_cache_dtype,
        async_scheduling=async_scheduling,
    )

    engine_args_dict = asdict(engine_args)
    llm = LLM(**engine_args_dict)

    try:
        # perf_counter() is monotonic, so the measured interval is not
        # affected by system clock adjustments.
        start_time = time.perf_counter()
        outputs = llm.generate(test_prompts, sampling_params)
        elapsed_time = time.perf_counter() - start_time
        if measure_time:
            return outputs, elapsed_time
        return outputs
    finally:
        del llm
        # Wait for TPUs to be released
        time.sleep(5)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def test_data_parallelism_performance(sampling_params: SamplingParams,
                                      monkeypatch: pytest.MonkeyPatch):
    """
    Measure generation time with and without data parallelism.

    Runs 128 long prompts once with data_parallel_size=1 and once with
    data_parallel_size=2 (async scheduling enabled in both runs) and prints
    the timings, speedup and throughput.

    Note: This is a performance benchmark test with large prompts. The
    speedup is reported but not asserted; the test only asserts that every
    prompt produced an output in both runs.
    """
    # Set through monkeypatch so the values are restored after the test
    # instead of leaking into the rest of the session.
    monkeypatch.setenv('VLLM_XLA_CHECK_RECOMPILATION', '1')
    monkeypatch.setenv('SKIP_JAX_PRECOMPILE', '0')
    monkeypatch.setenv('MODEL_IMPL_TYPE', 'flax_nnx')

    model_name = "Qwen/Qwen2.5-1.5B-Instruct"

    # Build one long base text by repeating a paragraph 20 times. Adjacent
    # string literals are joined before `*` applies, so the whole paragraph
    # is repeated, not just its last sentence.
    base_text = (
        "The rapid advancement of artificial intelligence has transformed numerous industries "
        "and continues to reshape our understanding of technology's potential. Machine learning "
        "algorithms have become increasingly sophisticated, enabling computers to perform tasks "
        "that were once thought to require human intelligence. From natural language processing "
        "to computer vision, AI systems are now capable of understanding context, recognizing "
        "patterns, and making decisions with remarkable accuracy. " *
        20  # Repeat to reach ~1k tokens
    )

    # Create 128 prompts with slight variations
    long_prompts = [
        f"Prompt {i}: {base_text} What are your thoughts on this topic?"
        for i in range(128)
    ]

    print(
        f"Generated {len(long_prompts)} prompts, approximate length: {len(base_text.split())} tokens each"
    )

    # Configuration for long sequences
    max_model_len = 2048
    max_num_batched_tokens = 4096
    max_num_seqs = 64

    # Run baseline (no data parallelism) with timing
    baseline_outputs, baseline_time = _run_inference_with_config(
        model_name=model_name,
        test_prompts=long_prompts,
        sampling_params=sampling_params,
        tensor_parallel_size=1,
        data_parallel_size=1,
        async_scheduling=True,
        measure_time=True,
        max_model_len=max_model_len,
        max_num_batched_tokens=max_num_batched_tokens,
        max_num_seqs=max_num_seqs,
    )

    # Run with model data parallelism and async scheduling with timing
    dp_outputs, dp_time = _run_inference_with_config(
        model_name=model_name,
        test_prompts=long_prompts,
        sampling_params=sampling_params,
        tensor_parallel_size=1,
        data_parallel_size=2,
        async_scheduling=True,
        measure_time=True,
        max_model_len=max_model_len,
        max_num_batched_tokens=max_num_batched_tokens,
        max_num_seqs=max_num_seqs,
    )

    # The timings only mean something if both runs handled every prompt.
    assert len(baseline_outputs) == len(long_prompts), (
        f"Expected {len(long_prompts)} baseline outputs, "
        f"got {len(baseline_outputs)}")
    assert len(dp_outputs) == len(long_prompts), (
        f"Expected {len(long_prompts)} data parallel outputs, "
        f"got {len(dp_outputs)}")

    # Calculate speedup
    speedup = baseline_time / dp_time if dp_time > 0 else 0

    print("✓ Performance test results:")
    print(f" Number of prompts: {len(long_prompts)}")
    print(f" Baseline time: {baseline_time:.2f}s")
    print(f" Data parallel time: {dp_time:.2f}s")
    print(f" Speedup: {speedup:.2f}x")
    print(
        f" Baseline throughput: {len(long_prompts)/baseline_time:.2f} prompts/s"
    )
    print(
        f" Data parallel throughput: {len(long_prompts)/dp_time:.2f} prompts/s"
    )
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
@pytest.mark.parametrize("model_impl_type", ["vllm", "flax_nnx"])
def test_model_data_parallelism(
    test_prompts: list,
    sampling_params: SamplingParams,
    model_impl_type: str,
    monkeypatch: pytest.MonkeyPatch,
):
    """
    Test model-wise data parallelism where data=2 in the mesh axis.
    This test verifies that the model can run with data parallelism enabled,
    duplicating the entire model across 2 data parallel workers.

    The run uses tensor_parallel_size=1 and data_parallel_size=2, once per
    value of ``model_impl_type``, and checks that every prompt yields
    non-empty generated text.
    """
    # Use Llama 1B for this test
    test_model = "meta-llama/Llama-3.2-1B-Instruct"

    # Set through monkeypatch so the values are restored after the test
    # instead of leaking into the rest of the session.
    monkeypatch.setenv('MODEL_IMPL_TYPE', model_impl_type)
    monkeypatch.setenv('VLLM_XLA_CHECK_RECOMPILATION', '0')
    monkeypatch.setenv('SKIP_JAX_PRECOMPILE', '1')

    # Test with data parallelism enabled
    outputs = _run_inference_with_config(
        model_name=test_model,
        test_prompts=test_prompts,
        sampling_params=sampling_params,
        tensor_parallel_size=1,
        data_parallel_size=2,
        async_scheduling=False,
    )

    # Verify we got outputs for all prompts
    assert len(outputs) == len(
        test_prompts
    ), f"Expected {len(test_prompts)} outputs, got {len(outputs)}"

    # Verify each output has generated text
    for output in outputs:
        assert len(output.outputs) > 0, "Output has no generated text"
        assert len(
            output.outputs[0].text.strip()) > 0, "Generated text is empty"

    print(f"✓ Model data parallelism test passed with {len(outputs)} outputs")
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def test_attention_data_parallelism(
    test_prompts: list,
    sampling_params: SamplingParams,
    monkeypatch: pytest.MonkeyPatch,
):
    """
    Test attention data parallelism where only the attention layer gets duplicated,
    attn_dp=2 in the mesh axis. This is useful when num_kv_heads < TP to avoid
    wasting KV cache memory.

    Equivalent to running examples/offline_inference.py with:
        --tensor_parallel_size=4
        --kv-cache-dtype=fp8
        --additional_config='{"sharding":{"sharding_strategy": {"enable_dp_attention":1}}}'
    """
    # Use Qwen3 0.6B for this test with reduced tensor parallelism
    test_model = "Qwen/Qwen3-0.6B"

    # Set through monkeypatch so the values are restored after the test
    # instead of leaking into the rest of the session.
    monkeypatch.setenv('MODEL_IMPL_TYPE', "flax_nnx")
    monkeypatch.setenv('VLLM_XLA_CHECK_RECOMPILATION', '0')
    monkeypatch.setenv('SKIP_JAX_PRECOMPILE', '1')

    additional_config = {
        "sharding": {
            "sharding_strategy": {
                "enable_dp_attention": 1
            }
        }
    }

    # Test with attention data parallelism enabled
    # Reduced tensor_parallel_size from 8 to 4 to avoid memory exhaustion
    outputs = _run_inference_with_config(
        model_name=test_model,
        test_prompts=test_prompts,
        sampling_params=sampling_params,
        tensor_parallel_size=4,
        data_parallel_size=1,
        additional_config=additional_config,
        kv_cache_dtype="fp8",
    )

    # Verify we got outputs for all prompts
    assert len(outputs) == len(
        test_prompts
    ), f"Expected {len(test_prompts)} outputs, got {len(outputs)}"

    # Verify each output has generated text
    for output in outputs:
        assert len(output.outputs) > 0, "Output has no generated text"
        assert len(
            output.outputs[0].text.strip()) > 0, "Generated text is empty"

    print(
        f"✓ Attention data parallelism test passed with {len(outputs)} outputs"
    )
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def test_data_parallelism_correctness(
    test_prompts: list,
    sampling_params: SamplingParams,
):
    """
    Test that data parallelism produces consistent results compared to a baseline.

    The first 10 prompts are run twice on the same model: once with
    data_parallel_size=1 (the baseline) and once with data_parallel_size=2,
    both with tensor_parallel_size=1 and async scheduling enabled. The test
    passes when at least 90% of the stripped generated texts are identical
    and, wherever both runs return logprobs, every compared per-token logprob
    differs by less than 0.15.

    Args:
        test_prompts: Prompts to draw from; only the first 10 are used.
        sampling_params: Sampling configuration shared by both runs.
    """
    # Single source of truth for the tolerated absolute logprob difference;
    # used by the per-token check, the summary line and the final assertion.
    logprob_tolerance = 0.15

    # NOTE(review): these variables are set process-wide and are not restored
    # afterwards, matching the other tests in this file.
    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
    os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '0'
    os.environ['MODEL_IMPL_TYPE'] = "flax_nnx"

    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
    # Use a smaller subset of prompts for correctness testing
    small_prompts = test_prompts[:10]

    # Run baseline (no data parallelism)
    baseline_outputs = _run_inference_with_config(
        model_name=model_name,
        test_prompts=small_prompts,
        sampling_params=sampling_params,
        tensor_parallel_size=1,
        data_parallel_size=1,
        async_scheduling=True,
    )

    # Run with model data parallelism and async scheduling
    dp_outputs = _run_inference_with_config(
        model_name=model_name,
        test_prompts=small_prompts,
        sampling_params=sampling_params,
        tensor_parallel_size=1,
        data_parallel_size=2,
        async_scheduling=True,
    )

    # Compare outputs - they should be identical for greedy sampling
    assert len(baseline_outputs) == len(dp_outputs)
    # Fail with a readable message instead of a ZeroDivisionError when the
    # match rate is computed below.
    assert baseline_outputs, "Baseline run produced no outputs"

    text_matches = 0
    text_mismatches = 0
    logprob_mismatches = 0
    max_logprob_diff = 0.0

    for i, (baseline, dp_result) in enumerate(zip(baseline_outputs,
                                                  dp_outputs)):
        baseline_text = baseline.outputs[0].text.strip()
        dp_text = dp_result.outputs[0].text.strip()

        # Check text output
        if baseline_text == dp_text:
            text_matches += 1
        else:
            text_mismatches += 1
            print(f"Text mismatch found in prompt {i}:")
            print(f" Baseline: {baseline_text}")
            print(f" Data Parallel: {dp_text}")

        # Check log probabilities (only when both runs returned them)
        baseline_logprobs = baseline.outputs[0].logprobs
        dp_logprobs = dp_result.outputs[0].logprobs
        if baseline_logprobs is None or dp_logprobs is None:
            continue

        assert len(baseline_logprobs) == len(dp_logprobs), \
            f"Logprobs length mismatch: {len(baseline_logprobs)} vs {len(dp_logprobs)}"

        for token_idx, (base_lp, dp_lp) in enumerate(
                zip(baseline_logprobs, dp_logprobs)):
            # Skip positions where either side has no entries.
            if not (base_lp and dp_lp):
                continue

            # Take the first entry of each mapping.
            # NOTE(review): presumably the first key is the sampled token -
            # confirm against the vLLM logprobs ordering.
            base_top_token = next(iter(base_lp))
            dp_top_token = next(iter(dp_lp))

            base_logprob_val = base_lp[base_top_token].logprob
            dp_logprob_val = dp_lp[dp_top_token].logprob

            # Calculate absolute difference
            diff = abs(base_logprob_val - dp_logprob_val)
            max_logprob_diff = max(max_logprob_diff, diff)

            # Allow small numerical differences
            if diff > logprob_tolerance:
                logprob_mismatches += 1
                print(f"Logprob mismatch in prompt {i}, token {token_idx}:")
                print(
                    f" Baseline token: {base_top_token}, logprob: {base_logprob_val:.6f}"
                )
                print(
                    f" DP token: {dp_top_token}, logprob: {dp_logprob_val:.6f}"
                )
                print(f" Difference: {diff:.6f}")

    print("✓ Correctness test results:")
    print(f" Text: {text_matches} matches, {text_mismatches} mismatches")
    print(f" Max logprob difference: {max_logprob_diff:.6e}")
    print(
        f" Significant logprob mismatches (>{logprob_tolerance}): {logprob_mismatches}"
    )

    # Allow for some variance due to potential numerical differences
    # but most outputs should match with greedy sampling
    text_match_rate = text_matches / len(baseline_outputs)
    assert text_match_rate >= 0.9, f"Text match rate {text_match_rate:.2%} is too low"

    # Log probabilities should be very close (allow small numerical errors)
    assert max_logprob_diff < logprob_tolerance, f"Max logprob difference {max_logprob_diff} is too large"
|