tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tests/e2e/test_async_scheduler.py
@@ -0,0 +1,211 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ import random
+ import string
+ import time
+
+ import pytest
+ from vllm import LLM, SamplingParams
+
+
+ @pytest.fixture
+ def sampling_config():
+     return SamplingParams(temperature=0,
+                           max_tokens=120,
+                           ignore_eos=True,
+                           repetition_penalty=1,
+                           frequency_penalty=0,
+                           presence_penalty=0,
+                           min_p=0,
+                           logprobs=None)
+
+
+ @pytest.fixture
+ def model_name():
+     return "Qwen/Qwen2.5-1.5B-Instruct"
+
+
+ def get_test_prompts():
+     """
+     Generates a list of prompts with a fixed word count.
+
+     The prompt set is controlled by two local constants:
+         num_prompts: The number of prompts to generate (500).
+         input_len_words: The number of repeated words in each prompt (120).
+
+     Returns:
+         A list of num_prompts strings, each consisting of the prefix
+         "Keep repeating: " followed by input_len_words repetitions of a
+         random lowercase letter.
+     """
+     num_prompts = 500
+     input_len_words = 120
+     prompts = []
+
+     # For example, if w = 's', the generated prompt will be
+     # "Keep repeating: s s s ..."
+     num_repetitions = input_len_words
+     prefix = "Keep repeating: "
+
+     for _ in range(num_prompts):
+         # 1. Pick a random lowercase letter
+         w = random.choice(string.ascii_lowercase)
+
+         # 2. Create the string of repeated words
+         #    This will have num_repetitions words
+         repeating_part = " ".join([w] * num_repetitions)
+
+         # 3. Combine with the prefix
+         print(f"{prefix}{repeating_part}")
+         prompts.append(f"{prefix}{repeating_part}")
+
+     return prompts
+
+
+ def _test_performance_helper(monkeypatch: pytest.MonkeyPatch,
+                              sampling_config: SamplingParams, model_name: str,
+                              min_speedup: float):
+     '''
+     Helper function to test async scheduler decoding performance.
+     Compares timing between reference LLM and async LLM using Qwen2.5-1.5B.
+     '''
+
+     with monkeypatch.context():
+         # Shared prompt set: 500 prompts, 120 repeated words each
+         test_prompts = get_test_prompts()
+
+         # Test reference LLM timing
+         ref_llm = LLM(model=model_name,
+                       max_model_len=800,
+                       max_num_seqs=24,
+                       max_num_batched_tokens=512,
+                       enable_prefix_caching=False,
+                       async_scheduling=0)
+
+         start_time = time.time()
+         _ = ref_llm.generate(test_prompts, sampling_config)
+         ref_time = time.time() - start_time
+
+         del ref_llm
+         # Waiting for TPUs to be released
+         time.sleep(10)
+
+         # Test async LLM timing
+         async_llm = LLM(model=model_name,
+                         max_model_len=800,
+                         max_num_seqs=24,
+                         max_num_batched_tokens=512,
+                         enable_prefix_caching=False,
+                         async_scheduling=1)
+
+         start_time = time.time()
+         _ = async_llm.generate(test_prompts, sampling_config)
+         async_time = time.time() - start_time
+
+         del async_llm
+         # Waiting for TPUs to be released
+         time.sleep(10)
+
+         speedup = ref_time / async_time
+         print(f"Reference LLM time: {ref_time:.2f}s")
+         print(f"Async LLM time: {async_time:.2f}s")
+         print(f"Speedup: {speedup:.2f}x")
+
+         assert speedup >= min_speedup, (
+             f"Expected at least {min_speedup}x speedup for async scheduler, "
+             f"got {speedup:.2f}x")
+
+
+ def test_performance(
+     monkeypatch: pytest.MonkeyPatch,
+     sampling_config: SamplingParams,
+     model_name: str,
+ ):
+     '''
+     Test that async scheduler decoding provides significant performance improvement.
+     Compares timing between reference LLM and async LLM using Qwen2.5-1.5B.
+     Expects async_llm to be at least 1.3x faster than ref_llm.
+     '''
+     min_speed_up = 1.3
+     _test_performance_helper(monkeypatch, sampling_config, model_name,
+                              min_speed_up)
+
+
+ def _test_correctness_helper(
+     monkeypatch: pytest.MonkeyPatch,
+     sampling_config: SamplingParams,
+     model_name: str,
+ ):
+     '''
+     Helper function to test async scheduler correctness.
+     The outputs of the original LLM and the async LLM should be the same
+     when async scheduler decoding is used.
+
+     Known Edge Case (KV Cache Swapping):
+     Under this case, even though the temperature is set to 0, the output is
+     still slightly different every time. This is expected behaviour, as the
+     normal scheduler behaves the same way, which makes it difficult to
+     design a test for that scenario.
+     '''
+     with monkeypatch.context():
+         test_prompts = get_test_prompts()
+
+         ref_llm = LLM(model=model_name,
+                       max_model_len=1024,
+                       max_num_seqs=100,
+                       async_scheduling=0)
+         ref_outputs = ref_llm.generate(test_prompts, sampling_config)
+
+         del ref_llm
+
+         # Waiting for TPUs to be released.
+         time.sleep(10)
+
+         async_llm = LLM(model=model_name,
+                         max_model_len=1024,
+                         max_num_seqs=100,
+                         async_scheduling=1)
+         async_outputs = async_llm.generate(test_prompts, sampling_config)
+
+         matches = 0
+         misses = 0
+         for ref_output, async_output in zip(ref_outputs, async_outputs):
+             if ref_output.outputs[0].text == async_output.outputs[0].text:
+                 matches += 1
+             else:
+                 misses += 1
+             print(f"ref_output: {ref_output.outputs[0].text}")
+             print(f"async_output: {async_output.outputs[0].text}")
+
+         assert misses == 0
+         del async_llm
+
+         # Waiting for TPUs to be released.
+         time.sleep(10)
+
+
+ def test_async_correctness(
+     monkeypatch: pytest.MonkeyPatch,
+     sampling_config: SamplingParams,
+     model_name: str,
+ ):
+     '''
+     The outputs of the original LLM and the async LLM should be the same
+     when the async scheduler is used.
+     '''
+
+     _test_correctness_helper(monkeypatch, sampling_config, model_name)
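
The tests above drive the async scheduler through pytest fixtures. For readers who want to reproduce the reference-vs-async timing comparison as a standalone script, here is a minimal sketch; it assumes a TPU host with vllm and tpu-inference installed, reuses the model name and LLM keyword arguments from the test above, and shortens the prompt set purely for illustration.

# Minimal sketch: compare async-scheduler decoding against the default
# scheduler outside pytest. Keyword arguments mirror test_async_scheduler.py.
import time

from vllm import LLM, SamplingParams

params = SamplingParams(temperature=0, max_tokens=120, ignore_eos=True)
prompts = ["Keep repeating: " + " ".join(["s"] * 120)] * 8

timings = {}
for label, async_scheduling in (("reference", False), ("async", True)):
    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
              max_model_len=800,
              max_num_seqs=24,
              max_num_batched_tokens=512,
              enable_prefix_caching=False,
              async_scheduling=async_scheduling)
    start = time.time()
    llm.generate(prompts, params)
    timings[label] = time.time() - start
    del llm
    time.sleep(10)  # wait for TPUs to be released before the next run

print(f"Speedup: {timings['reference'] / timings['async']:.2f}x")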
tests/e2e/test_data_parallel.py
@@ -0,0 +1,393 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+ import os
+ import time
+ from dataclasses import asdict
+
+ import pytest
+ from vllm import LLM, EngineArgs, SamplingParams
+
+
+ @pytest.fixture(autouse=True)
+ def setup_new_model_design():
+     """Automatically set NEW_MODEL_DESIGN=1 for all tests."""
+     os.environ['NEW_MODEL_DESIGN'] = '1'
+
+
+ @pytest.fixture
+ def test_prompts():
+     """Simple test prompts for data parallelism testing."""
+     return [
+         "Hello, my name is",
+         "The capital of France is",
+         "The colors of the rainbow are",
+         "The future of AI is",
+         "The president of the United States is",
+         "How many players are on a standard soccer team?",
+         "In Greek mythology, who is the god of the sea?",
+         "What is the capital of Australia?",
+         "What is the largest planet in our solar system?",
+         "Who developed the theory of general relativity?",
+     ]
+
+
+ @pytest.fixture
+ def sampling_params():
+     """Standard sampling parameters for testing."""
+     return SamplingParams(
+         temperature=0.0,
+         max_tokens=32,
+         ignore_eos=True,
+         logprobs=1,
+     )
+
+
+ def _run_inference_with_config(model_name: str,
+                                test_prompts: list,
+                                sampling_params: SamplingParams,
+                                tensor_parallel_size: int = 1,
+                                data_parallel_size: int = 1,
+                                additional_config: dict = {},
+                                kv_cache_dtype: str = "auto",
+                                enable_prefix_caching: bool = False,
+                                async_scheduling: bool = False,
+                                measure_time: bool = False,
+                                max_model_len: int = 32,
+                                max_num_batched_tokens: int = 128,
+                                max_num_seqs: int = 16):
+     """Helper function to run inference with the specified configuration.
+
+     Returns:
+         If measure_time=True: (outputs, elapsed_time) tuple
+         If measure_time=False: outputs list
+     """
+
+     # Build the engine args, mirroring examples/offline_inference.py
+     engine_args = EngineArgs(
+         model=model_name,
+         max_model_len=max_model_len,
+         tensor_parallel_size=tensor_parallel_size,
+         data_parallel_size=data_parallel_size,
+         gpu_memory_utilization=0.98,
+         max_num_batched_tokens=max_num_batched_tokens,
+         max_num_seqs=max_num_seqs,
+         enable_prefix_caching=enable_prefix_caching,
+         additional_config=additional_config,
+         kv_cache_dtype=kv_cache_dtype,
+         async_scheduling=async_scheduling,
+     )
+
+     engine_args_dict = asdict(engine_args)
+     llm = LLM(**engine_args_dict)
+
+     try:
+         start_time = time.time()
+         outputs = llm.generate(test_prompts, sampling_params)
+         elapsed_time = time.time() - start_time
+         if measure_time:
+             return outputs, elapsed_time
+         else:
+             return outputs
+     finally:
+         del llm
+         # Wait for TPUs to be released
+         time.sleep(5)
+
+
+ def test_data_parallelism_performance(sampling_params: SamplingParams):
+     """
+     Test that data parallelism provides performance improvements compared to
+     the baseline. This test measures the execution time with 128 prompts of
+     roughly 1k tokens each.
+
+     Note: This is a performance benchmark test with large prompts.
+     """
+     os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '1'
+     os.environ['SKIP_JAX_PRECOMPILE'] = '0'
+     os.environ['MODEL_IMPL_TYPE'] = 'flax_nnx'
+
+     model_name = "Qwen/Qwen2.5-1.5B-Instruct"
+
+     # Generate 128 prompts of approximately 1k tokens each,
+     # built from a base prompt of repeated text
+     base_text = (
+         "The rapid advancement of artificial intelligence has transformed numerous industries "
+         "and continues to reshape our understanding of technology's potential. Machine learning "
+         "algorithms have become increasingly sophisticated, enabling computers to perform tasks "
+         "that were once thought to require human intelligence. From natural language processing "
+         "to computer vision, AI systems are now capable of understanding context, recognizing "
+         "patterns, and making decisions with remarkable accuracy. " *
+         20  # Repeat to reach ~1k tokens
+     )
+
+     # Create 128 prompts with slight variations
+     long_prompts = [
+         f"Prompt {i}: {base_text} What are your thoughts on this topic?"
+         for i in range(128)
+     ]
+
+     print(f"Generated {len(long_prompts)} prompts, approximately "
+           f"{len(base_text.split())} words each")
+
+     # Configuration for long sequences
+     max_model_len = 2048
+     max_num_batched_tokens = 4096
+     max_num_seqs = 64
+
+     # Run baseline (no data parallelism) with timing
+     baseline_outputs, baseline_time = _run_inference_with_config(
+         model_name=model_name,
+         test_prompts=long_prompts,
+         sampling_params=sampling_params,
+         tensor_parallel_size=1,
+         data_parallel_size=1,
+         async_scheduling=True,
+         measure_time=True,
+         max_model_len=max_model_len,
+         max_num_batched_tokens=max_num_batched_tokens,
+         max_num_seqs=max_num_seqs,
+     )
+
+     # Run with model data parallelism and async scheduling, with timing
+     dp_outputs, dp_time = _run_inference_with_config(
+         model_name=model_name,
+         test_prompts=long_prompts,
+         sampling_params=sampling_params,
+         tensor_parallel_size=1,
+         data_parallel_size=2,
+         async_scheduling=True,
+         measure_time=True,
+         max_model_len=max_model_len,
+         max_num_batched_tokens=max_num_batched_tokens,
+         max_num_seqs=max_num_seqs,
+     )
+
+     # Calculate speedup
+     speedup = baseline_time / dp_time if dp_time > 0 else 0
+
+     print("✓ Performance test results:")
+     print(f"  Number of prompts: {len(long_prompts)}")
+     print(f"  Baseline time: {baseline_time:.2f}s")
+     print(f"  Data parallel time: {dp_time:.2f}s")
+     print(f"  Speedup: {speedup:.2f}x")
+     print(
+         f"  Baseline throughput: {len(long_prompts) / baseline_time:.2f} prompts/s"
+     )
+     print(
+         f"  Data parallel throughput: {len(long_prompts) / dp_time:.2f} prompts/s"
+     )
+
+
+ @pytest.mark.parametrize("model_impl_type", ["vllm", "flax_nnx"])
+ def test_model_data_parallelism(
+     test_prompts: list,
+     sampling_params: SamplingParams,
+     model_impl_type: str,
+ ):
+     """
+     Test model-wise data parallelism where data=2 in the mesh axis.
+     This test verifies that the model can run with data parallelism enabled,
+     duplicating the entire model across 2 data parallel workers.
+
+     Equivalent to:
+     python examples/offline_inference.py --tensor_parallel_size=1 --data_parallel_size=2
+     """
+     # Use Llama 1B for this test
+     test_model = "meta-llama/Llama-3.2-1B-Instruct"
+     os.environ['MODEL_IMPL_TYPE'] = model_impl_type
+     os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '0'
+     os.environ['SKIP_JAX_PRECOMPILE'] = '1'
+
+     # Test with data parallelism enabled
+     outputs = _run_inference_with_config(
+         model_name=test_model,
+         test_prompts=test_prompts,
+         sampling_params=sampling_params,
+         tensor_parallel_size=1,
+         data_parallel_size=2,
+         async_scheduling=False,
+     )
+
+     # Verify we got outputs for all prompts
+     assert len(outputs) == len(test_prompts), (
+         f"Expected {len(test_prompts)} outputs, got {len(outputs)}")
+
+     # Verify each output has generated text
+     for output in outputs:
+         assert len(output.outputs) > 0, "Output has no generated text"
+         assert len(
+             output.outputs[0].text.strip()) > 0, "Generated text is empty"
+
+     print(f"✓ Model data parallelism test passed with {len(outputs)} outputs")
+
+
+ def test_attention_data_parallelism(
+     test_prompts: list,
+     sampling_params: SamplingParams,
+ ):
+     """
+     Test attention data parallelism where only the attention layer gets
+     duplicated, attn_dp=2 in the mesh axis. This is useful when
+     num_kv_heads < TP, to avoid wasting KV cache memory.
+
+     Equivalent to:
+     python examples/offline_inference.py --tensor_parallel_size=4 --kv-cache-dtype=fp8 \
+         --additional_config='{"sharding":{"sharding_strategy": {"enable_dp_attention":1}}}'
+     """
+     # Use Qwen3 0.6B for this test with reduced tensor parallelism
+     test_model = "Qwen/Qwen3-0.6B"
+
+     os.environ['MODEL_IMPL_TYPE'] = "flax_nnx"
+     os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '0'
+     os.environ['SKIP_JAX_PRECOMPILE'] = '1'
+
+     additional_config = {
+         "sharding": {
+             "sharding_strategy": {
+                 "enable_dp_attention": 1
+             }
+         }
+     }
+
+     # Test with attention data parallelism enabled
+     # Reduced tensor_parallel_size from 8 to 4 to avoid memory exhaustion
+     outputs = _run_inference_with_config(
+         model_name=test_model,
+         test_prompts=test_prompts,
+         sampling_params=sampling_params,
+         tensor_parallel_size=4,
+         data_parallel_size=1,
+         additional_config=additional_config,
+         kv_cache_dtype="fp8",
+     )
+
+     # Verify we got outputs for all prompts
+     assert len(outputs) == len(test_prompts), (
+         f"Expected {len(test_prompts)} outputs, got {len(outputs)}")
+
+     # Verify each output has generated text
+     for output in outputs:
+         assert len(output.outputs) > 0, "Output has no generated text"
+         assert len(
+             output.outputs[0].text.strip()) > 0, "Generated text is empty"
+
+     print(
+         f"✓ Attention data parallelism test passed with {len(outputs)} outputs"
+     )
+
+
+ def test_data_parallelism_correctness(
+     test_prompts: list,
+     sampling_params: SamplingParams,
+ ):
+     """
+     Test that data parallelism produces consistent results compared to a baseline.
+     This test compares outputs from a single-device run with data parallel runs
+     to ensure correctness, including log probabilities.
+     """
+     os.environ['SKIP_JAX_PRECOMPILE'] = '1'
+     os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '0'
+     os.environ['MODEL_IMPL_TYPE'] = "flax_nnx"
+
+     model_name = "Qwen/Qwen2.5-1.5B-Instruct"
+     # Use a smaller subset of prompts for correctness testing
+     small_prompts = test_prompts[:10]
+
+     # Run baseline (no data parallelism)
+     baseline_outputs = _run_inference_with_config(
+         model_name=model_name,
+         test_prompts=small_prompts,
+         sampling_params=sampling_params,
+         tensor_parallel_size=1,
+         data_parallel_size=1,
+         async_scheduling=True,
+     )
+
+     # Run with model data parallelism and async scheduling
+     dp_outputs = _run_inference_with_config(
+         model_name=model_name,
+         test_prompts=small_prompts,
+         sampling_params=sampling_params,
+         tensor_parallel_size=1,
+         data_parallel_size=2,
+         async_scheduling=True,
+     )
+
+     # Compare outputs - they should be identical for greedy sampling
+     assert len(baseline_outputs) == len(dp_outputs)
+
+     text_matches = 0
+     text_mismatches = 0
+     logprob_mismatches = 0
+     max_logprob_diff = 0.0
+
+     for i, (baseline, dp_result) in enumerate(zip(baseline_outputs,
+                                                   dp_outputs)):
+         baseline_text = baseline.outputs[0].text.strip()
+         dp_text = dp_result.outputs[0].text.strip()
+
+         # Check text output
+         if baseline_text == dp_text:
+             text_matches += 1
+         else:
+             text_mismatches += 1
+             print(f"Text mismatch found in prompt {i}:")
+             print(f"  Baseline: {baseline_text}")
+             print(f"  Data Parallel: {dp_text}")
+
+         # Check log probabilities
+         baseline_logprobs = baseline.outputs[0].logprobs
+         dp_logprobs = dp_result.outputs[0].logprobs
+
+         if baseline_logprobs is not None and dp_logprobs is not None:
+             # Compare log probabilities for each token
+             assert len(baseline_logprobs) == len(dp_logprobs), \
+                 f"Logprobs length mismatch: {len(baseline_logprobs)} vs {len(dp_logprobs)}"
+
+             for token_idx, (base_lp, dp_lp) in enumerate(
+                     zip(baseline_logprobs, dp_logprobs)):
+                 # Get the top logprob value for the selected token
+                 if base_lp and dp_lp:
+                     # Get the top token's logprob from each
+                     base_top_token = list(base_lp.keys())[0]
+                     dp_top_token = list(dp_lp.keys())[0]
+
+                     base_logprob_val = base_lp[base_top_token].logprob
+                     dp_logprob_val = dp_lp[dp_top_token].logprob
+
+                     # Calculate absolute difference
+                     diff = abs(base_logprob_val - dp_logprob_val)
+                     max_logprob_diff = max(max_logprob_diff, diff)
+
+                     # Allow small numerical differences
+                     if diff > 0.15:
+                         logprob_mismatches += 1
+                         print(f"Logprob mismatch in prompt {i}, "
+                               f"token {token_idx}:")
+                         print(f"  Baseline token: {base_top_token}, "
+                               f"logprob: {base_logprob_val:.6f}")
+                         print(f"  DP token: {dp_top_token}, "
+                               f"logprob: {dp_logprob_val:.6f}")
+                         print(f"  Difference: {diff:.6f}")
+
+     print("✓ Correctness test results:")
+     print(f"  Text: {text_matches} matches, {text_mismatches} mismatches")
+     print(f"  Max logprob difference: {max_logprob_diff:.6e}")
+     print(f"  Significant logprob mismatches (>0.15): {logprob_mismatches}")
+
+     # Allow for some variance due to potential numerical differences,
+     # but most outputs should match with greedy sampling
+     text_match_rate = text_matches / len(baseline_outputs)
+     assert text_match_rate >= 0.9, f"Text match rate {text_match_rate:.2%} is too low"
+
+     # Log probabilities should be very close (allow small numerical errors)
+     assert max_logprob_diff < 0.15, f"Max logprob difference {max_logprob_diff} is too large"
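
For reference, the attention data-parallel configuration exercised in test_attention_data_parallelism can also be driven as a standalone script through the same EngineArgs path that _run_inference_with_config uses. The following is a minimal sketch, assuming a TPU host with vllm and tpu-inference installed; every value below mirrors the test and helper above rather than a recommended production setting.

# Minimal sketch: attention data parallelism via additional_config,
# mirroring test_data_parallel.py above.
import os
from dataclasses import asdict

from vllm import LLM, EngineArgs, SamplingParams

# Environment variables copied from the test above.
os.environ['MODEL_IMPL_TYPE'] = "flax_nnx"
os.environ['SKIP_JAX_PRECOMPILE'] = '1'

engine_args = EngineArgs(
    model="Qwen/Qwen3-0.6B",
    max_model_len=32,
    tensor_parallel_size=4,
    data_parallel_size=1,
    gpu_memory_utilization=0.98,
    max_num_batched_tokens=128,
    max_num_seqs=16,
    enable_prefix_caching=False,
    kv_cache_dtype="fp8",
    additional_config={
        "sharding": {
            "sharding_strategy": {
                "enable_dp_attention": 1
            }
        }
    },
)

# Same EngineArgs -> dict -> LLM construction as the helper above.
llm = LLM(**asdict(engine_args))
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)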