tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (251)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +22 -1
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +31 -9
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +77 -36
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -71
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +158 -63
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +53 -30
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +54 -2
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +105 -57
  232. tpu_inference/runner/utils.py +2 -2
  233. tpu_inference/spec_decode/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/__init__.py +13 -0
  235. tpu_inference/spec_decode/jax/eagle3.py +65 -19
  236. tpu_inference/tpu_info.py +14 -0
  237. tpu_inference/utils.py +72 -44
  238. tpu_inference/worker/__init__.py +13 -0
  239. tpu_inference/worker/tpu_worker.py +65 -52
  240. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  241. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  242. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  244. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  245. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  246. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  247. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  248. tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
  249. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  250. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  251. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tests/e2e/test_runai_model_streamer_loader.py
@@ -0,0 +1,104 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # This file contains end-to-end tests for the RunAI Model Streamer loader.
+ #
+ # The RunAI Model Streamer is a high-performance model loader that serves as an
+ # alternative to the default Hugging Face loader. Instead of downloading a model
+ # to local disk, it streams the weights from object storage (like GCS) into
+ # GPU memory. This streaming process is significantly faster than the
+ # traditional disk-based loading method.
+
+ # The tests in this file verify that loading model weights using the
+ # streamer produces the same results as loading the same model using the
+ # standard Hugging Face loader. This ensures the correctness of the streamer
+ # integration.
+
+ # The tests are performed by:
+ # 1. Loading a model from Google Cloud Storage using the `runai_streamer` format.
+ # 2. Generating output with this model.
+ # 3. Loading the same model from Hugging Face using the default loader.
+ # 4. Generating output with this second model.
+ # 5. Asserting that the outputs from both models are identical.
+
+ from __future__ import annotations
+
+ import time
+
+ import pytest
+ from vllm import LLM, SamplingParams
+
+
+ @pytest.fixture
+ def sampling_config():
+     return SamplingParams(temperature=0, max_tokens=10, ignore_eos=True)
+
+
+ @pytest.fixture
+ # TODO(amacaskill): Replace with GKE owned GCS bucket.
+ def gcs_model_name():
+     return "gs://vertex-model-garden-public-us/llama3/llama3-8b-hf"
+
+
+ @pytest.fixture
+ def hf_model_name():
+     return "meta-llama/Meta-Llama-3-8B"
+
+
+ @pytest.fixture
+ def prompt():
+     return "Hello, my name is"
+
+
+ def test_correctness(
+     sampling_config: SamplingParams,
+     gcs_model_name: str,
+     hf_model_name: str,
+     prompt: str,
+     monkeypatch: pytest.MonkeyPatch,
+ ):
+     '''
+     Compare the outputs of a model loaded from GCS via runai_model_streamer
+     and a model loaded from Hugging Face. The outputs should be the same.
+     These tests use tensor_parallel_size=1. The model is 16GB,
+     and v6e has 32GB of HBM, so it will fit.
+     '''
+     # Set ENV variables so that runai_model_streamer uses anonymous GCS access.
+     monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "fake-project")
+     monkeypatch.setenv("RUNAI_STREAMER_GCS_USE_ANONYMOUS_CREDENTIALS", "true")
+     monkeypatch.setenv("CLOUD_STORAGE_EMULATOR_ENDPOINT",
+                        "https://storage.googleapis.com")
+     gcs_llm = LLM(model=gcs_model_name,
+                   load_format="runai_streamer",
+                   max_model_len=128,
+                   max_num_seqs=16,
+                   max_num_batched_tokens=256)
+     gcs_outputs = gcs_llm.generate([prompt], sampling_config)
+     gcs_output_text = gcs_outputs[0].outputs[0].text
+     del gcs_llm
+     time.sleep(10)  # Wait for TPUs to be released
+
+     # Test with Hugging Face model
+     hf_llm = LLM(model=hf_model_name,
+                  max_model_len=128,
+                  max_num_seqs=16,
+                  max_num_batched_tokens=256)
+     hf_outputs = hf_llm.generate([prompt], sampling_config)
+     hf_output_text = hf_outputs[0].outputs[0].text
+     del hf_llm
+     time.sleep(10)  # Wait for TPUs to be released
+
+     assert gcs_output_text == hf_output_text, (
+         f"Outputs do not match! "
+         f"GCS output: {gcs_output_text}, HF output: {hf_output_text}")
tests/e2e/test_sampling_params.py
@@ -0,0 +1,269 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # This file contains end-to-end tests for sampling parameters.
+ #
+ # Sampling parameters control how the model selects tokens during generation.
+ # These tests verify that temperature, top_p, top_k, and logprobs work correctly.
+ #
+ # The tests in this file verify that:
+ # 1. Temperature=0 produces deterministic (greedy) outputs
+ # 2. Higher temperature produces more varied outputs
+ # 3. top_p (nucleus sampling) correctly constrains token selection
+ # 4. top_k correctly limits the number of candidate tokens
+ # 5. logprobs returns probability information for generated tokens
+
+ from __future__ import annotations
+
+ import pytest
+ from vllm import LLM, SamplingParams
+
+
+ @pytest.fixture(scope="module")
+ def llm():
+     """Create a shared LLM instance for all tests in this module."""
+     return LLM(
+         model='meta-llama/Llama-3.2-1B-Instruct',
+         max_model_len=1024,
+         max_num_seqs=4,
+         enable_prefix_caching=False,
+     )
+
+
+ class TestTemperature:
+     """Tests for temperature sampling parameter."""
+
+     def test_temperature_zero_is_deterministic(self, llm: LLM):
+         """Temperature=0 should produce identical outputs across multiple runs."""
+         prompt = "What is 2 + 2? Answer with just the number:"
+         sampling_params = SamplingParams(temperature=0, max_tokens=10)
+
+         outputs1 = llm.generate([prompt], sampling_params)
+         outputs2 = llm.generate([prompt], sampling_params)
+
+         assert outputs1[0].outputs[0].text == outputs2[0].outputs[0].text
+
+     def test_high_temperature_produces_variation(self, llm: LLM):
+         """High temperature should produce varied outputs across multiple runs."""
+         prompt = "Write a random word:"
+         sampling_params = SamplingParams(temperature=2,
+                                          max_tokens=10,
+                                          top_k=4096)
+
+         # Run multiple times and collect unique outputs
+         unique_outputs = set()
+         num_runs = 10
+         for _ in range(num_runs):
+             outputs = llm.generate([prompt], sampling_params)
+             unique_outputs.add(outputs[0].outputs[0].text)
+
+         # With high temperature, we expect some variation
+         assert len(unique_outputs) > 1, (
+             "High temperature should produce varied outputs")
+
+
+ class TestTopP:
+     """Tests for top_p (nucleus sampling) parameter."""
+
+     def test_top_p_restricts_sampling(self, llm: LLM):
+         """top_p=1.0 vs lower values should affect output diversity."""
+         prompt = "Name a color:"
+
+         # With top_p=1.0 (consider all tokens)
+         sampling_params_full = SamplingParams(temperature=0.8,
+                                               top_p=1.0,
+                                               max_tokens=5)
+
+         # With top_p=0.1 (very restrictive, only top tokens)
+         sampling_params_restricted = SamplingParams(temperature=0.8,
+                                                     top_p=0.1,
+                                                     max_tokens=5)
+
+         # Collect outputs with full nucleus
+         full_outputs = set()
+         for _ in range(10):
+             outputs = llm.generate([prompt], sampling_params_full)
+             full_outputs.add(outputs[0].outputs[0].text)
+
+         # Collect outputs with restricted nucleus
+         restricted_outputs = set()
+         for _ in range(10):
+             outputs = llm.generate([prompt], sampling_params_restricted)
+             restricted_outputs.add(outputs[0].outputs[0].text)
+
+         # Restricted top_p should generally produce less variety
+         # (though this isn't guaranteed, it's a reasonable expectation)
+         assert len(
+             restricted_outputs) >= 1, "Should produce at least one output"
+         assert len(full_outputs) >= 1, "Should produce at least one output"
+
+     def test_top_p_with_temperature_zero(self, llm: LLM):
+         """top_p should have no effect when temperature=0 (greedy)."""
+         prompt = "The capital of France is"
+
+         sampling_params_1 = SamplingParams(temperature=0,
+                                            top_p=0.1,
+                                            max_tokens=10)
+         sampling_params_2 = SamplingParams(temperature=0,
+                                            top_p=0.9,
+                                            max_tokens=10)
+
+         outputs1 = llm.generate([prompt], sampling_params_1)
+         outputs2 = llm.generate([prompt], sampling_params_2)
+
+         # Both should produce identical outputs since temperature=0
+         assert outputs1[0].outputs[0].text == outputs2[0].outputs[0].text
+
+
+ class TestTopK:
+     """Tests for top_k sampling parameter."""
+
+     def test_top_k_restricts_sampling(self, llm: LLM):
+         """top_k should limit the candidate tokens for sampling."""
+         prompt = "Pick a number between 1 and 10:"
+
+         # top_k=1 is equivalent to greedy (always pick the most likely)
+         sampling_params_k1 = SamplingParams(temperature=1.0,
+                                             top_k=1,
+                                             max_tokens=5)
+
+         # top_k=-1 considers all tokens
+         sampling_params_all = SamplingParams(temperature=1.0,
+                                              top_k=-1,
+                                              max_tokens=5)
+
+         # With top_k=1, outputs should be deterministic
+         outputs_k1_run1 = llm.generate([prompt], sampling_params_k1)
+         outputs_k1_run2 = llm.generate([prompt], sampling_params_k1)
+         assert outputs_k1_run1[0].outputs[0].text == outputs_k1_run2[
+             0].outputs[0].text
+
+         # With top_k=-1 and temperature=1.0, we may see variation
+         all_outputs = set()
+         for _ in range(10):
+             outputs = llm.generate([prompt], sampling_params_all)
+             all_outputs.add(outputs[0].outputs[0].text)
+
+         # Should produce at least one valid output
+         assert len(all_outputs) >= 1
+
+     def test_top_k_with_temperature_zero(self, llm: LLM):
+         """top_k should have no effect when temperature=0 (greedy)."""
+         prompt = "The largest planet is"
+
+         sampling_params_k5 = SamplingParams(temperature=0,
+                                             top_k=5,
+                                             max_tokens=10)
+         sampling_params_k50 = SamplingParams(temperature=0,
+                                              top_k=50,
+                                              max_tokens=10)
+
+         outputs1 = llm.generate([prompt], sampling_params_k5)
+         outputs2 = llm.generate([prompt], sampling_params_k50)
+
+         # Both should produce identical outputs since temperature=0
+         assert outputs1[0].outputs[0].text == outputs2[0].outputs[0].text
+
+
+ class TestLogprobs:
+     """Tests for logprobs parameter."""
+
+     def test_logprobs_returns_probabilities(self, llm: LLM):
+         """logprobs parameter should return log probabilities for tokens."""
+         prompt = "Hello"
+         sampling_params = SamplingParams(temperature=0,
+                                          max_tokens=5,
+                                          logprobs=5)
+
+         outputs = llm.generate([prompt], sampling_params)
+         output = outputs[0].outputs[0]
+
+         # Check that logprobs are returned
+         assert output.logprobs is not None, "logprobs should be returned"
+         assert len(output.logprobs) > 0, "logprobs should contain entries"
+
+         # Each token should have logprob information
+         for token_logprobs in output.logprobs:
+             assert token_logprobs is not None
+             # Should have up to 5 top logprobs as requested
+             assert len(token_logprobs) <= 5
+
+     def test_logprobs_none_returns_no_probabilities(self, llm: LLM):
+         """When logprobs=None, no log probabilities should be returned."""
+         prompt = "Hello"
+         sampling_params = SamplingParams(temperature=0,
+                                          max_tokens=5,
+                                          logprobs=None)
+
+         outputs = llm.generate([prompt], sampling_params)
+         output = outputs[0].outputs[0]
+
+         # logprobs should be None when not requested
+         assert output.logprobs is None, "logprobs should be None when not requested"
+
+     def test_logprobs_values_are_valid(self, llm: LLM):
+         """Log probabilities should be valid (negative or zero)."""
+         prompt = "The sky is"
+         sampling_params = SamplingParams(temperature=0,
+                                          max_tokens=3,
+                                          logprobs=3)
+
+         outputs = llm.generate([prompt], sampling_params)
+         output = outputs[0].outputs[0]
+
+         assert output.logprobs is not None
+         for token_logprobs in output.logprobs:
+             for token_id, logprob_obj in token_logprobs.items():
+                 # Log probabilities should be <= 0
+                 assert logprob_obj.logprob <= 0, (
+                     f"Log probability should be <= 0, got {logprob_obj.logprob}"
+                 )
+
+
+ class TestCombinedParameters:
+     """Tests for combinations of sampling parameters."""
+
+     def test_top_p_and_top_k_combined(self, llm: LLM):
+         """top_p and top_k can be used together."""
+         prompt = "List a fruit:"
+         sampling_params = SamplingParams(
+             temperature=0.7,
+             top_p=0.9,
+             top_k=50,
+             max_tokens=10,
+         )
+
+         outputs = llm.generate([prompt], sampling_params)
+         assert len(outputs[0].outputs[0].text) > 0
+
+     def test_all_params_with_logprobs(self, llm: LLM):
+         """All sampling parameters should work together with logprobs."""
+         prompt = "Complete this sentence: The weather today is"
+         sampling_params = SamplingParams(
+             temperature=0.5,
+             top_p=0.95,
+             top_k=40,
+             max_tokens=10,
+             logprobs=3,
+         )
+
+         outputs = llm.generate([prompt], sampling_params)
+         output = outputs[0].outputs[0]
+
+         # Should have generated text
+         assert len(output.text) > 0
+
+         # Should have logprobs
+         assert output.logprobs is not None
+         assert len(output.logprobs) > 0
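
The parameter semantics these tests rely on can be illustrated outside of vLLM. The toy NumPy snippet below (not vLLM or tpu-inference code; the function name is made up for illustration) applies temperature, top_k, and top_p filtering to a hand-written logits vector, mirroring what the tests assert: temperature=0 degenerates to argmax, top_k keeps only the k most likely tokens, and top_p keeps the smallest set of most-likely tokens whose cumulative probability reaches p.

import numpy as np

def sample_filtered(logits, temperature=1.0, top_k=-1, top_p=1.0, rng=None):
    """Toy re-implementation of temperature / top_k / top_p filtering."""
    logits = np.asarray(logits, dtype=np.float64)
    if temperature == 0:  # greedy decoding: top_k / top_p no longer matter
        return int(np.argmax(logits))
    scaled = logits / temperature
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()
    order = np.argsort(probs)[::-1]  # token ids, most likely first
    keep = np.ones_like(probs, dtype=bool)
    if top_k > 0:  # keep only the k most likely tokens
        keep[order[top_k:]] = False
    if top_p < 1.0:  # keep the smallest prefix whose cumulative prob >= top_p
        cutoff = int(np.searchsorted(np.cumsum(probs[order]), top_p)) + 1
        keep[order[cutoff:]] = False
    probs = np.where(keep, probs, 0.0)
    probs /= probs.sum()
    rng = rng or np.random.default_rng(0)
    return int(rng.choice(len(probs), p=probs))

logits = [2.0, 1.0, 0.5, -1.0]
assert sample_filtered(logits, temperature=0) == 0             # deterministic
assert sample_filtered(logits, temperature=1.0, top_k=1) == 0  # top_k=1 is greedy too
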
tests/e2e/test_speculative_decoding.py
@@ -0,0 +1,311 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ import os
+ import random
+ import string
+ import time
+
+ import pytest
+ from vllm import LLM, SamplingParams
+
+
+ # TODO (Qiliang Cui): remove this when XLA fixes the recursive jit call issue.
+ def _is_v7x():
+     # jax.devices() will hang so use IS_FOR_V7X to indicate the version.
+     return os.environ.get("IS_FOR_V7X", "false") == "true"
+
+
+ def _get_tensor_parallel_size():
+     # Work around an XLA issue.
+     if _is_v7x():
+         return 2
+     return 1
+
+
+ def get_ngram_test_prompts():
+     num_prompts = 100
+     prompts = []
+
+     for _ in range(num_prompts):
+         w = random.choice(list(string.ascii_lowercase))
+         prompts.append(
+             f"Keep repeating: {w} {w} {w} {w} {w} {w} {w} {w} {w} {w}")
+
+     return prompts
+
+
+ def get_eagle3_test_prompts():
+     num_prompts = 100
+     prompts = []
+
+     for _ in range(num_prompts):
+         prompts.append(
+             "Predict the continuation of this sequence: 1 2 3 4 5 6 7 8")
+
+     return prompts
+
+
+ def get_test_prompts(speculative_config: dict):
+     if speculative_config['method'] == 'ngram':
+         return get_ngram_test_prompts()
+     elif speculative_config['method'] == 'eagle3':
+         return get_eagle3_test_prompts()
+     else:
+         raise NotImplementedError(
+             f"{speculative_config['method']} is not supported yet.")
+
+
+ @pytest.fixture
+ def sampling_config():
+     return SamplingParams(temperature=0,
+                           max_tokens=32,
+                           ignore_eos=True,
+                           repetition_penalty=1,
+                           frequency_penalty=0,
+                           presence_penalty=0,
+                           min_p=0,
+                           logprobs=None)
+
+
+ @pytest.fixture
+ def model_name():
+     return "Qwen/Qwen2.5-0.5B-Instruct"
+
+
+ # TODO(pooyam): run vLLM engine with InProcClient (`VLLM_ENABLE_V1_MULTIPROCESSING = 0`) mode to avoid TPU contention among processes.
+ def _test_correctness_helper(
+     monkeypatch: pytest.MonkeyPatch,
+     sampling_config: SamplingParams,
+     model_name: str,
+     speculative_config: dict,
+ ):
+     '''
+     Helper function to test speculative decoding correctness.
+     The outputs of an original LLM and a speculative LLM
+     should be the same when using speculative decoding.
+     '''
+     with monkeypatch.context():
+         test_prompts = get_test_prompts(speculative_config)
+
+         ref_llm = LLM(model=model_name,
+                       max_model_len=1024,
+                       max_num_seqs=4,
+                       tensor_parallel_size=_get_tensor_parallel_size())
+         ref_outputs = ref_llm.generate(test_prompts, sampling_config)
+
+         del ref_llm
+
+         # Waiting for TPUs to be released.
+         time.sleep(10)
+
+         spec_llm = LLM(model=model_name,
+                        speculative_config=speculative_config,
+                        max_model_len=1024,
+                        max_num_seqs=4,
+                        tensor_parallel_size=_get_tensor_parallel_size())
+         spec_outputs = spec_llm.generate(test_prompts, sampling_config)
+
+         matches = 0
+         misses = 0
+         for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+             if ref_output.outputs[0].text == spec_output.outputs[0].text:
+                 matches += 1
+             else:
+                 misses += 1
+                 print(f"ref_output: {ref_output.outputs[0].text}")
+                 print(f"spec_output: {spec_output.outputs[0].text}")
+
+         assert misses == 0
+         del spec_llm
+
+         # Waiting for TPUs to be released.
+         time.sleep(10)
+
+
+ def test_ngram_correctness_greedy(
+     monkeypatch: pytest.MonkeyPatch,
+     sampling_config: SamplingParams,
+     model_name: str,
+ ):
+     '''
+     The outputs of an original LLM and a speculative LLM should be the same
+     when using ngram speculative decoding with greedy sampling.
+     '''
+     _test_correctness_helper(
+         monkeypatch, sampling_config, model_name, {
+             "method": "ngram",
+             "prompt_lookup_max": 5,
+             "prompt_lookup_min": 3,
+             "num_speculative_tokens": 3,
+         })
+
+
+ def test_ngram_correctness_random(
+     monkeypatch: pytest.MonkeyPatch,
+     sampling_config: SamplingParams,
+     model_name: str,
+ ):
+     '''
+     The outputs of an original LLM and a speculative LLM should be the same
+     when using ngram speculative decoding with random sampling.
+     '''
+     # Modify sampling config for random sampling
+     sampling_config.temperature = 0.01
+     sampling_config.top_p = 0.9
+     sampling_config.top_k = 5
+
+     _test_correctness_helper(
+         monkeypatch, sampling_config, model_name, {
+             "method": "ngram",
+             "prompt_lookup_max": 5,
+             "prompt_lookup_min": 3,
+             "num_speculative_tokens": 3,
+         })
+
+
+ def _test_performance_helper(
+     monkeypatch: pytest.MonkeyPatch,
+     sampling_config: SamplingParams,
+     speculative_config: dict,
+     min_speedup: float,
+ ):
+     '''
+     Helper function to test speculative decoding performance.
+     Compares timing between reference LLM and speculative LLM using Llama 3 8B.
+     '''
+     model_name = "meta-llama/Llama-3.1-8B-Instruct"
+
+     with monkeypatch.context():
+         # Use a smaller set of prompts for performance testing
+         test_prompts = get_test_prompts(speculative_config)
+
+         # Test reference LLM timing
+         ref_llm = LLM(model=model_name,
+                       max_model_len=1024,
+                       max_num_seqs=1,
+                       enable_prefix_caching=False,
+                       tensor_parallel_size=_get_tensor_parallel_size())
+
+         start_time = time.time()
+         _ = ref_llm.generate(test_prompts, sampling_config)
+         ref_time = time.time() - start_time
+
+         del ref_llm
+
+         # Waiting for TPUs to be released
+         time.sleep(10)
+
+         # Test speculative LLM timing with max_num_seqs=1
+         spec_llm = LLM(model=model_name,
+                        speculative_config=speculative_config,
+                        max_model_len=1024,
+                        max_num_seqs=1,
+                        tensor_parallel_size=_get_tensor_parallel_size(),
+                        enable_prefix_caching=False)
+
+         start_time = time.time()
+         _ = spec_llm.generate(test_prompts, sampling_config)
+         spec_time = time.time() - start_time
+
+         del spec_llm
+         # Waiting for TPUs to be released
+         time.sleep(10)
+
+         speedup = ref_time / spec_time
+         print(f"Reference LLM time: {ref_time:.2f}s")
+         print(f"Speculative LLM time: {spec_time:.2f}s")
+         print(f"Speedup: {speedup:.2f}x")
+
+         # TODO(pooyam): Make this tighter once we have better performance.
+         assert speedup >= min_speedup, f"Expected at least {min_speedup}x speedup for {speculative_config['method']}, got {speedup:.2f}x"
+
+
+ def test_ngram_performance_greedy(
+     monkeypatch: pytest.MonkeyPatch,
+     sampling_config: SamplingParams,
+ ):
+     '''
+     Test that speculative decoding provides significant performance improvement.
+     Compares timing between reference LLM and speculative LLM using Llama 3 8B.
+     Expects spec_llm to be at least 3x faster than ref_llm.
+     '''
+     _test_performance_helper(
+         monkeypatch, sampling_config, {
+             "method": "ngram",
+             "prompt_lookup_max": 2,
+             "prompt_lookup_min": 2,
+             "num_speculative_tokens": 4,
+         }, 1.2 if _is_v7x() else 3.0)
+
+
+ def test_ngram_performance_random(
+     monkeypatch: pytest.MonkeyPatch,
+     sampling_config: SamplingParams,
+ ):
+     '''
+     Test that speculative decoding provides significant performance improvement.
+     Compares timing between reference LLM and speculative LLM using Llama 3 8B.
+     Expects spec_llm to be at least 3x faster than ref_llm.
+     '''
+     sampling_config.temperature = 0.01
+     sampling_config.top_p = 0.9
+     sampling_config.top_k = 5
+
+     _test_performance_helper(
+         monkeypatch, sampling_config, {
+             "method": "ngram",
+             "prompt_lookup_max": 2,
+             "prompt_lookup_min": 2,
+             "num_speculative_tokens": 4,
+         }, 1.5 if _is_v7x() else 3.0)
+
+
+ def test_eagle3_correctness(
+     monkeypatch: pytest.MonkeyPatch,
+     sampling_config: SamplingParams,
+ ):
+     '''
+     The outputs of an original LLM and a speculative LLM should be the same
+     when using eagle-3 speculative decoding.
+     '''
+     model_name = 'meta-llama/Meta-Llama-3-8B-Instruct'
+
+     _test_correctness_helper(
+         monkeypatch, sampling_config, model_name, {
+             'model': "unkmaster/EAGLE3-LLaMA3.1-Instruct-8B",
+             "num_speculative_tokens": 3,
+             "method": "eagle3",
+             "draft_tensor_parallel_size": 1
+         })
+
+
+ def test_eagle3_performance(
+     monkeypatch: pytest.MonkeyPatch,
+     sampling_config: SamplingParams,
+ ):
+     '''
+     Test that speculative decoding provides significant performance improvement.
+     Compares timing between reference LLM and speculative LLM using Llama 3 8B.
+     Expects spec_llm to be at least 1.8x faster than ref_llm.
+     '''
+     _test_performance_helper(
+         monkeypatch, sampling_config, {
+             "method": "eagle3",
+             "model": "unkmaster/EAGLE3-LLaMA3.1-Instruct-8B",
+             "num_speculative_tokens": 2,
+             "draft_tensor_parallel_size": 1
+         }, 1.2 if _is_v7x() else 1.8)
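
For context on the ngram settings exercised above (prompt_lookup_min, prompt_lookup_max, num_speculative_tokens): they configure prompt-lookup drafting, where the drafter searches the existing token context for an earlier occurrence of the current trailing n-gram and proposes the tokens that followed it, which the target model then verifies. The sketch below is a plain-Python illustration of that matching step only, with a made-up function name, not the vLLM or tpu-inference implementation.

def propose_ngram_draft(token_ids, prompt_lookup_min=2, prompt_lookup_max=5,
                        num_speculative_tokens=4):
    """Illustrative n-gram prompt-lookup drafting: reuse the tokens that
    followed an earlier occurrence of the current trailing n-gram."""
    for n in range(prompt_lookup_max, prompt_lookup_min - 1, -1):  # longer n-grams first
        if len(token_ids) <= n:
            continue
        suffix = token_ids[-n:]
        for start in range(len(token_ids) - n):
            if token_ids[start:start + n] == suffix:
                draft = token_ids[start + n:start + n + num_speculative_tokens]
                if draft:
                    return draft  # proposed tokens, to be verified by the target model
    return []  # no match: fall back to normal one-token-at-a-time decoding

# Highly repetitive prompts (like the "Keep repeating: ..." prompts above) always
# yield a match, which is why the ngram tests expect a large speedup on them.
context = [7, 7, 7, 7, 7, 7]
assert propose_ngram_draft(context, prompt_lookup_min=2, prompt_lookup_max=2,
                           num_speculative_tokens=3) == [7, 7, 7]
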