tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
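
For orientation, the e2e tests below drive the tpu-inference backend through vLLM's offline API (LLM, EngineArgs, SamplingParams). The following is a minimal sketch of that usage pattern, not a file from the package; the model name and size limits are illustrative values taken from the tests themselves.

    from vllm import LLM, SamplingParams

    # Small model and short limits purely for illustration; any model exercised
    # by the tests below (e.g. meta-llama/Llama-3.2-1B-Instruct) would do.
    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
              max_model_len=128,
              max_num_seqs=16)
    params = SamplingParams(temperature=0.0, max_tokens=32)
    outputs = llm.generate(["Hello, my name is"], params)
    print(outputs[0].outputs[0].text)
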
tests/e2e/test_pipeline_parallel.py
@@ -0,0 +1,265 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+ import os
+ import time
+ from dataclasses import asdict
+
+ import pytest
+ from vllm import LLM, EngineArgs, SamplingParams
+
+
+ @pytest.fixture
+ def model_name():
+     """Choose Llama3 8B as the test model, as it supports PP in the JAX model impl."""
+     return "meta-llama/Llama-3.1-8B-Instruct"
+
+
+ @pytest.fixture
+ def test_prompts():
+     """Simple test prompts for pipeline parallelism testing."""
+     return [
+         "Hello, my name is",
+         "The capital of France is",
+         "The colors of the rainbow are",
+         "The future of AI is",
+         "The president of the United States is",
+         "How many players are on a standard soccer team?",
+         "In Greek mythology, who is the god of the sea?",
+         "What is the capital of Australia?",
+         "What is the largest planet in our solar system?",
+         "Who developed the theory of general relativity?",
+     ]
+
+
+ @pytest.fixture
+ def sampling_params():
+     """Standard sampling parameters for testing."""
+     return SamplingParams(
+         temperature=0.0,
+         max_tokens=32,
+         ignore_eos=True,
+         logprobs=1,
+     )
+
+
+ def _run_inference_with_config(model_name: str,
+                                test_prompts: list,
+                                sampling_params: SamplingParams,
+                                tensor_parallel_size: int = 1,
+                                pipeline_parallel_size: int = 1,
+                                additional_config: dict = {},
+                                kv_cache_dtype: str = "auto",
+                                enable_prefix_caching: bool = False) -> list:
+     """Helper function to run inference with the specified configuration."""
+
+     # Create LLM args using a parser-based approach similar to offline_inference.py
+     engine_args = EngineArgs(
+         model=model_name,
+         max_model_len=128,
+         tensor_parallel_size=tensor_parallel_size,
+         pipeline_parallel_size=pipeline_parallel_size,
+         gpu_memory_utilization=0.95,
+         max_num_batched_tokens=128,
+         max_num_seqs=16,
+         enable_prefix_caching=enable_prefix_caching,
+         additional_config=additional_config,
+         kv_cache_dtype=kv_cache_dtype,
+     )
+
+     engine_args_dict = asdict(engine_args)
+     llm = LLM(**engine_args_dict)
+
+     try:
+         outputs = llm.generate(test_prompts, sampling_params)
+         return outputs
+     finally:
+         del llm
+         # Wait for TPUs to be released
+         time.sleep(5)
+
+
+ @pytest.mark.skip(reason="PP is not fully enabled.")
+ def test_pipeline_parallelism_jax_model(
+     model_name: str,
+     test_prompts: list,
+     sampling_params: SamplingParams,
+ ):
+     """
+     Test that pipeline parallelism works on JAX models.
+
+     Equivalent to:
+     python examples/offline_inference.py --tensor_parallel_size=1 --pipeline_parallel_size=2
+     """
+     # Test with pipeline parallelism enabled
+     outputs = _run_inference_with_config(
+         model_name=model_name,
+         test_prompts=test_prompts,
+         sampling_params=sampling_params,
+         tensor_parallel_size=1,
+         pipeline_parallel_size=2,
+     )
+
+     # Verify we got outputs for all prompts
+     assert len(outputs) == len(test_prompts)
+
+     # Verify each output has generated text
+     for output in outputs:
+         assert len(output.outputs) > 0
+         assert len(output.outputs[0].text.strip()) > 0
+
+     print(
+         f"✓ Pipeline Parallelism Jax model test passed with {len(outputs)} outputs"
+     )
+
+
+ @pytest.mark.skip(reason="PP is not fully enabled.")
+ def test_pipeline_parallelism_vllm_model(
+     model_name: str,
+     test_prompts: list,
+     sampling_params: SamplingParams,
+ ):
+     """
+     Test that pipeline parallelism works on vLLM models, and that it also
+     works together with tensor parallelism.
+
+     Equivalent to:
+     MODEL_IMPL_TYPE=vllm python examples/offline_inference.py --tensor_parallel_size=1 --pipeline_parallel_size=2
+     """
+
+     os.environ['MODEL_IMPL_TYPE'] = 'vllm'
+     # Test with pipeline parallelism enabled
+     outputs = _run_inference_with_config(
+         model_name=model_name,
+         test_prompts=test_prompts,
+         sampling_params=sampling_params,
+         tensor_parallel_size=1,
+         pipeline_parallel_size=2,
+     )
+
+     # Verify we got outputs for all prompts
+     assert len(outputs) == len(test_prompts)
+
+     # Verify each output has generated text
+     for output in outputs:
+         assert len(output.outputs) > 0
+         assert len(output.outputs[0].text.strip()) > 0
+
+     print(
+         f"✓ Pipeline Parallelism vLLM model test passed with {len(outputs)} outputs"
+     )
+
+
+ @pytest.mark.skip(reason="PP is not fully enabled.")
+ def test_pipeline_parallelism_jax_model_correctness(
+     model_name: str,
+     test_prompts: list,
+     sampling_params: SamplingParams,
+ ):
+     """
+     Test that pipeline parallelism produces consistent results compared to a baseline.
+     This test compares outputs from a single-device run with pipeline-parallel runs
+     to ensure correctness, including log probabilities.
+     """
+     os.environ['SKIP_JAX_PRECOMPILE'] = '1'
+     os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '0'
+
+     # Use a smaller subset of prompts for correctness testing
+     small_prompts = test_prompts[:10]
+
+     # Run baseline (no PP)
+     baseline_outputs = _run_inference_with_config(
+         model_name=model_name,
+         test_prompts=small_prompts,
+         sampling_params=sampling_params,
+         tensor_parallel_size=1,
+         pipeline_parallel_size=1,
+     )
+
+     # Run with pipeline parallelism enabled
+     pp_outputs = _run_inference_with_config(
+         model_name=model_name,
+         test_prompts=small_prompts,
+         sampling_params=sampling_params,
+         tensor_parallel_size=1,
+         pipeline_parallel_size=2,
+     )
+
+     # Compare outputs. In theory they should be identical for greedy sampling;
+     # in practice there may be some differences, but overall the outputs should
+     # be very similar.
+
+     # An example:
+     # prompt: What is the capital of Australia?
+     # Both answers should be acceptable:
+     # The capital of Australia is Canberra. It is located in the Australian Capital Territory (ACT) and is home to many
+     # Canberra is the capital of Australia. It is located in the Australian Capital Territory (ACT) and is home to
+     assert len(baseline_outputs) == len(pp_outputs)
+
+     text_matches = 0
+     text_mismatches = 0
+     logprob_mismatches = 0
+     max_logprob_diff = 0.0
+
+     for i, (baseline, pp_result) in enumerate(zip(baseline_outputs,
+                                                   pp_outputs)):
+         baseline_text = baseline.outputs[0].text.strip()
+         pp_text = pp_result.outputs[0].text.strip()
+
+         # Check text output
+         if baseline_text == pp_text:
+             text_matches += 1
+         else:
+             text_mismatches += 1
+             print(f"Text mismatch found in prompt {i}:")
+             print(f" Baseline: {baseline_text}")
+             print(f" Pipeline Parallel: {pp_text}")
+
+         # Check log probabilities
+         baseline_logprobs = baseline.outputs[0].logprobs
+         pp_logprobs = pp_result.outputs[0].logprobs
+         if baseline_logprobs is not None and pp_logprobs is not None:
+             # Compare log probabilities for each token
+             assert len(baseline_logprobs) == len(pp_logprobs), \
+                 f"Logprobs length mismatch: {len(baseline_logprobs)} vs {len(pp_logprobs)}"
+             for token_idx, (base_lp, pp_lp) in enumerate(
+                     zip(baseline_logprobs, pp_logprobs)):
+                 # Get the top logprob value for the selected token
+                 if base_lp and pp_lp:
+                     # Get the top token's logprob from each
+                     base_top_token = list(base_lp.keys())[0]
+                     pp_top_token = list(pp_lp.keys())[0]
+
+                     base_logprob_val = base_lp[base_top_token].logprob
+                     pp_logprob_val = pp_lp[pp_top_token].logprob
+
+                     # Calculate the absolute difference
+                     diff = abs(base_logprob_val - pp_logprob_val)
+                     max_logprob_diff = max(max_logprob_diff, diff)
+
+                     # Allow small numerical differences (e.g., 1e-3)
+                     if diff > 1e-3:
+                         logprob_mismatches += 1
+                         print(
+                             f"Logprob mismatch in prompt {i}, token {token_idx}:"
+                         )
+                         print(
+                             f" Baseline token: {base_top_token}, logprob: {base_logprob_val:.6f}"
+                         )
+                         print(
+                             f" PP token: {pp_top_token}, logprob: {pp_logprob_val:.6f}"
+                         )
+                         print(f" Difference: {diff:.6f}")
+
+     print("✓ Correctness test results:")
+     print(f" Text: {text_matches} matches, {text_mismatches} mismatches")
+     print(f" Max logprob difference: {max_logprob_diff:.6e}")
+     print(f" Significant logprob mismatches (>1e-3): {logprob_mismatches}")
+
+     # Allow for some variance due to potential numerical differences,
+     # but most outputs should match with greedy sampling
+     text_match_rate = text_matches / len(baseline_outputs)
+     assert text_match_rate >= 0.9, f"Text match rate {text_match_rate:.2%} is too low"
+
+     # Log probabilities should be very close (allow small numerical errors)
+     assert max_logprob_diff < 1, f"Max logprob difference {max_logprob_diff} is too large"
tests/e2e/test_runai_model_streamer_loader.py
@@ -0,0 +1,104 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # This file contains end-to-end tests for the RunAI Model Streamer loader.
+ #
+ # The RunAI Model Streamer is a high-performance model loader that serves as an
+ # alternative to the default Hugging Face loader. Instead of downloading a model
+ # to local disk, it streams the weights from object storage (like GCS) into
+ # GPU memory. This streaming process is significantly faster than the
+ # traditional disk-based loading method.
+
+ # The tests in this file verify that loading model weights using the
+ # streamer produces the same results as loading the same model using the
+ # standard Hugging Face loader. This ensures the correctness of the streamer
+ # integration.
+
+ # The tests are performed by:
+ # 1. Loading a model from Google Cloud Storage using the `runai_streamer` format.
+ # 2. Generating output with this model.
+ # 3. Loading the same model from Hugging Face using the default loader.
+ # 4. Generating output with this second model.
+ # 5. Asserting that the outputs from both models are identical.
+
+ from __future__ import annotations
+
+ import time
+
+ import pytest
+ from vllm import LLM, SamplingParams
+
+
+ @pytest.fixture
+ def sampling_config():
+     return SamplingParams(temperature=0, max_tokens=10, ignore_eos=True)
+
+
+ @pytest.fixture
+ # TODO(amacaskill): Replace with GKE owned GCS bucket.
+ def gcs_model_name():
+     return "gs://vertex-model-garden-public-us/llama3/llama3-8b-hf"
+
+
+ @pytest.fixture
+ def hf_model_name():
+     return "meta-llama/Meta-Llama-3-8B"
+
+
+ @pytest.fixture
+ def prompt():
+     return "Hello, my name is"
+
+
+ def test_correctness(
+     sampling_config: SamplingParams,
+     gcs_model_name: str,
+     hf_model_name: str,
+     prompt: str,
+     monkeypatch: pytest.MonkeyPatch,
+ ):
+     '''
+     Compare the outputs of a model loaded from GCS via runai_model_streamer
+     and a model loaded from Hugging Face. The outputs should be the same.
+     These tests attempt to use tensor_parallel_size=1. The model is 16GB,
+     and v6e has 32GB of HBM, so it will fit.
+     '''
+     # Set ENV variables so that runai_model_streamer uses anonymous GCS access.
+     monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "fake-project")
+     monkeypatch.setenv("RUNAI_STREAMER_GCS_USE_ANONYMOUS_CREDENTIALS", "true")
+     monkeypatch.setenv("CLOUD_STORAGE_EMULATOR_ENDPOINT",
+                        "https://storage.googleapis.com")
+     gcs_llm = LLM(model=gcs_model_name,
+                   load_format="runai_streamer",
+                   max_model_len=128,
+                   max_num_seqs=16,
+                   max_num_batched_tokens=256)
+     gcs_outputs = gcs_llm.generate([prompt], sampling_config)
+     gcs_output_text = gcs_outputs[0].outputs[0].text
+     del gcs_llm
+     time.sleep(10)  # Wait for TPUs to be released
+
+     # Test with the Hugging Face model
+     hf_llm = LLM(model=hf_model_name,
+                  max_model_len=128,
+                  max_num_seqs=16,
+                  max_num_batched_tokens=256)
+     hf_outputs = hf_llm.generate([prompt], sampling_config)
+     hf_output_text = hf_outputs[0].outputs[0].text
+     del hf_llm
+     time.sleep(10)  # Wait for TPUs to be released
+
+     assert gcs_output_text == hf_output_text, (
+         f"Outputs do not match! "
+         f"GCS output: {gcs_output_text}, HF output: {hf_output_text}")
tests/e2e/test_sampling_params.py
@@ -0,0 +1,269 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # This file contains end-to-end tests for sampling parameters.
+ #
+ # Sampling parameters control how the model selects tokens during generation.
+ # These tests verify that temperature, top_p, top_k, and logprobs work correctly.
+ #
+ # The tests in this file verify that:
+ # 1. Temperature=0 produces deterministic (greedy) outputs
+ # 2. Higher temperature produces more varied outputs
+ # 3. top_p (nucleus sampling) correctly constrains token selection
+ # 4. top_k correctly limits the number of candidate tokens
+ # 5. logprobs returns probability information for generated tokens
+
+ from __future__ import annotations
+
+ import pytest
+ from vllm import LLM, SamplingParams
+
+
+ @pytest.fixture(scope="module")
+ def llm():
+     """Create a shared LLM instance for all tests in this module."""
+     return LLM(
+         model='meta-llama/Llama-3.2-1B-Instruct',
+         max_model_len=1024,
+         max_num_seqs=4,
+         enable_prefix_caching=False,
+     )
+
+
+ class TestTemperature:
+     """Tests for temperature sampling parameter."""
+
+     def test_temperature_zero_is_deterministic(self, llm: LLM):
+         """Temperature=0 should produce identical outputs across multiple runs."""
+         prompt = "What is 2 + 2? Answer with just the number:"
+         sampling_params = SamplingParams(temperature=0, max_tokens=10)
+
+         outputs1 = llm.generate([prompt], sampling_params)
+         outputs2 = llm.generate([prompt], sampling_params)
+
+         assert outputs1[0].outputs[0].text == outputs2[0].outputs[0].text
+
+     def test_high_temperature_produces_variation(self, llm: LLM):
+         """High temperature should produce varied outputs across multiple runs."""
+         prompt = "Write a random word:"
+         sampling_params = SamplingParams(temperature=2,
+                                          max_tokens=10,
+                                          top_k=4096)
+
+         # Run multiple times and collect unique outputs
+         unique_outputs = set()
+         num_runs = 10
+         for _ in range(num_runs):
+             outputs = llm.generate([prompt], sampling_params)
+             unique_outputs.add(outputs[0].outputs[0].text)
+
+         # With high temperature, we expect some variation
+         assert len(unique_outputs) > 1, (
+             "High temperature should produce varied outputs")
+
+
+ class TestTopP:
+     """Tests for top_p (nucleus sampling) parameter."""
+
+     def test_top_p_restricts_sampling(self, llm: LLM):
+         """top_p=1.0 vs lower values should affect output diversity."""
+         prompt = "Name a color:"
+
+         # With top_p=1.0 (consider all tokens)
+         sampling_params_full = SamplingParams(temperature=0.8,
+                                               top_p=1.0,
+                                               max_tokens=5)
+
+         # With top_p=0.1 (very restrictive, only top tokens)
+         sampling_params_restricted = SamplingParams(temperature=0.8,
+                                                     top_p=0.1,
+                                                     max_tokens=5)
+
+         # Collect outputs with full nucleus
+         full_outputs = set()
+         for _ in range(10):
+             outputs = llm.generate([prompt], sampling_params_full)
+             full_outputs.add(outputs[0].outputs[0].text)
+
+         # Collect outputs with restricted nucleus
+         restricted_outputs = set()
+         for _ in range(10):
+             outputs = llm.generate([prompt], sampling_params_restricted)
+             restricted_outputs.add(outputs[0].outputs[0].text)
+
+         # Restricted top_p should generally produce less variety
+         # (though this isn't guaranteed, it's a reasonable expectation)
+         assert len(restricted_outputs) >= 1, "Should produce at least one output"
+         assert len(full_outputs) >= 1, "Should produce at least one output"
+
+     def test_top_p_with_temperature_zero(self, llm: LLM):
+         """top_p should have no effect when temperature=0 (greedy)."""
+         prompt = "The capital of France is"
+
+         sampling_params_1 = SamplingParams(temperature=0,
+                                            top_p=0.1,
+                                            max_tokens=10)
+         sampling_params_2 = SamplingParams(temperature=0,
+                                            top_p=0.9,
+                                            max_tokens=10)
+
+         outputs1 = llm.generate([prompt], sampling_params_1)
+         outputs2 = llm.generate([prompt], sampling_params_2)
+
+         # Both should produce identical outputs since temperature=0
+         assert outputs1[0].outputs[0].text == outputs2[0].outputs[0].text
+
+
+ class TestTopK:
+     """Tests for top_k sampling parameter."""
+
+     def test_top_k_restricts_sampling(self, llm: LLM):
+         """top_k should limit the candidate tokens for sampling."""
+         prompt = "Pick a number between 1 and 10:"
+
+         # top_k=1 is equivalent to greedy (always pick the most likely)
+         sampling_params_k1 = SamplingParams(temperature=1.0,
+                                             top_k=1,
+                                             max_tokens=5)
+
+         # top_k=-1 considers all tokens
+         sampling_params_all = SamplingParams(temperature=1.0,
+                                              top_k=-1,
+                                              max_tokens=5)
+
+         # With top_k=1, outputs should be deterministic
+         outputs_k1_run1 = llm.generate([prompt], sampling_params_k1)
+         outputs_k1_run2 = llm.generate([prompt], sampling_params_k1)
+         assert outputs_k1_run1[0].outputs[0].text == outputs_k1_run2[0].outputs[0].text
+
+         # With top_k=-1 and temperature=1.0, we may see variation
+         all_outputs = set()
+         for _ in range(10):
+             outputs = llm.generate([prompt], sampling_params_all)
+             all_outputs.add(outputs[0].outputs[0].text)
+
+         # Should produce at least one valid output
+         assert len(all_outputs) >= 1
+
+     def test_top_k_with_temperature_zero(self, llm: LLM):
+         """top_k should have no effect when temperature=0 (greedy)."""
+         prompt = "The largest planet is"
+
+         sampling_params_k5 = SamplingParams(temperature=0,
+                                             top_k=5,
+                                             max_tokens=10)
+         sampling_params_k50 = SamplingParams(temperature=0,
+                                              top_k=50,
+                                              max_tokens=10)
+
+         outputs1 = llm.generate([prompt], sampling_params_k5)
+         outputs2 = llm.generate([prompt], sampling_params_k50)
+
+         # Both should produce identical outputs since temperature=0
+         assert outputs1[0].outputs[0].text == outputs2[0].outputs[0].text
+
+
+ class TestLogprobs:
+     """Tests for logprobs parameter."""
+
+     def test_logprobs_returns_probabilities(self, llm: LLM):
+         """logprobs parameter should return log probabilities for tokens."""
+         prompt = "Hello"
+         sampling_params = SamplingParams(temperature=0,
+                                          max_tokens=5,
+                                          logprobs=5)
+
+         outputs = llm.generate([prompt], sampling_params)
+         output = outputs[0].outputs[0]
+
+         # Check that logprobs are returned
+         assert output.logprobs is not None, "logprobs should be returned"
+         assert len(output.logprobs) > 0, "logprobs should contain entries"
+
+         # Each token should have logprob information
+         for token_logprobs in output.logprobs:
+             assert token_logprobs is not None
+             # Should have up to 5 top logprobs as requested
+             assert len(token_logprobs) <= 5
+
+     def test_logprobs_none_returns_no_probabilities(self, llm: LLM):
+         """When logprobs=None, no log probabilities should be returned."""
+         prompt = "Hello"
+         sampling_params = SamplingParams(temperature=0,
+                                          max_tokens=5,
+                                          logprobs=None)
+
+         outputs = llm.generate([prompt], sampling_params)
+         output = outputs[0].outputs[0]
+
+         # logprobs should be None when not requested
+         assert output.logprobs is None, "logprobs should be None when not requested"
+
+     def test_logprobs_values_are_valid(self, llm: LLM):
+         """Log probabilities should be valid (negative or zero)."""
+         prompt = "The sky is"
+         sampling_params = SamplingParams(temperature=0,
+                                          max_tokens=3,
+                                          logprobs=3)
+
+         outputs = llm.generate([prompt], sampling_params)
+         output = outputs[0].outputs[0]
+
+         assert output.logprobs is not None
+         for token_logprobs in output.logprobs:
+             for token_id, logprob_obj in token_logprobs.items():
+                 # Log probabilities should be <= 0
+                 assert logprob_obj.logprob <= 0, (
+                     f"Log probability should be <= 0, got {logprob_obj.logprob}"
+                 )
+
+
+ class TestCombinedParameters:
+     """Tests for combinations of sampling parameters."""
+
+     def test_top_p_and_top_k_combined(self, llm: LLM):
+         """top_p and top_k can be used together."""
+         prompt = "List a fruit:"
+         sampling_params = SamplingParams(
+             temperature=0.7,
+             top_p=0.9,
+             top_k=50,
+             max_tokens=10,
+         )
+
+         outputs = llm.generate([prompt], sampling_params)
+         assert len(outputs[0].outputs[0].text) > 0
+
+     def test_all_params_with_logprobs(self, llm: LLM):
+         """All sampling parameters should work together with logprobs."""
+         prompt = "Complete this sentence: The weather today is"
+         sampling_params = SamplingParams(
+             temperature=0.5,
+             top_p=0.95,
+             top_k=40,
+             max_tokens=10,
+             logprobs=3,
+         )
+
+         outputs = llm.generate([prompt], sampling_params)
+         output = outputs[0].outputs[0]
+
+         # Should have generated text
+         assert len(output.text) > 0
+
+         # Should have logprobs
+         assert output.logprobs is not None
+         assert len(output.logprobs) > 0