tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. (The original page linked to more details here.)

Files changed (248)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +14 -0
  31. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  32. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
  35. tests/layers/__init__.py +13 -0
  36. tests/layers/common/__init__.py +13 -0
  37. tests/layers/common/test_attention_interface.py +156 -0
  38. tests/layers/common/test_quantization.py +149 -0
  39. tests/layers/jax/__init__.py +13 -0
  40. tests/layers/jax/attention/__init__.py +13 -0
  41. tests/layers/jax/attention/test_common_attention.py +103 -0
  42. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  43. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  44. tests/layers/jax/moe/__init__.py +13 -0
  45. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  46. tests/layers/jax/sample/__init__.py +13 -0
  47. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  48. tests/layers/jax/sample/test_sampling.py +115 -0
  49. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  50. tests/layers/jax/test_layers.py +155 -0
  51. tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
  52. tests/layers/jax/test_rope.py +93 -0
  53. tests/layers/jax/test_sharding.py +159 -0
  54. tests/layers/jax/test_transformer_block.py +152 -0
  55. tests/layers/vllm/__init__.py +13 -0
  56. tests/layers/vllm/test_attention.py +363 -0
  57. tests/layers/vllm/test_awq.py +406 -0
  58. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  59. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  61. tests/layers/vllm/test_fp8.py +17 -0
  62. tests/layers/vllm/test_mxfp4.py +320 -0
  63. tests/layers/vllm/test_unquantized.py +662 -0
  64. tests/layers/vllm/utils.py +87 -0
  65. tests/lora/__init__.py +13 -0
  66. tests/lora/conftest.py +14 -0
  67. tests/lora/test_bgmv.py +14 -0
  68. tests/lora/test_layers.py +25 -8
  69. tests/lora/test_lora.py +15 -1
  70. tests/lora/test_lora_perf.py +14 -0
  71. tests/models/__init__.py +13 -0
  72. tests/models/common/__init__.py +13 -0
  73. tests/models/common/test_model_loader.py +455 -0
  74. tests/models/jax/__init__.py +13 -0
  75. tests/models/jax/test_deepseek_v3.py +401 -0
  76. tests/models/jax/test_llama3.py +184 -0
  77. tests/models/jax/test_llama4.py +298 -0
  78. tests/models/jax/test_llama_eagle3.py +197 -0
  79. tests/models/jax/test_llama_guard_4.py +242 -0
  80. tests/models/jax/test_qwen2.py +172 -0
  81. tests/models/jax/test_qwen2_5_vl.py +605 -0
  82. tests/models/jax/test_qwen3.py +169 -0
  83. tests/models/jax/test_weight_loading.py +180 -0
  84. tests/models/jax/utils/__init__.py +13 -0
  85. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  86. tests/platforms/__init__.py +13 -0
  87. tests/platforms/test_tpu_platform.py +54 -0
  88. tests/runner/__init__.py +13 -0
  89. tests/runner/test_block_table.py +395 -0
  90. tests/runner/test_input_batch.py +226 -0
  91. tests/runner/test_kv_cache.py +220 -0
  92. tests/runner/test_kv_cache_manager.py +498 -0
  93. tests/runner/test_multimodal_manager.py +429 -0
  94. tests/runner/test_persistent_batch_manager.py +84 -0
  95. tests/runner/test_speculative_decoding_manager.py +368 -0
  96. tests/runner/test_structured_decoding_manager.py +220 -0
  97. tests/runner/test_tpu_runner.py +261 -0
  98. tests/runner/test_tpu_runner_dp.py +1099 -0
  99. tests/runner/test_tpu_runner_mesh.py +200 -0
  100. tests/runner/test_utils.py +411 -0
  101. tests/spec_decode/__init__.py +13 -0
  102. tests/spec_decode/test_eagle3.py +311 -0
  103. tests/test_base.py +14 -0
  104. tests/test_tpu_info.py +14 -0
  105. tests/test_utils.py +1 -43
  106. tests/worker/__init__.py +13 -0
  107. tests/worker/tpu_worker_test.py +414 -0
  108. tpu_inference/__init__.py +14 -0
  109. tpu_inference/core/__init__.py +13 -0
  110. tpu_inference/core/sched/__init__.py +13 -0
  111. tpu_inference/core/sched/dp_scheduler.py +372 -56
  112. tpu_inference/distributed/__init__.py +13 -0
  113. tpu_inference/distributed/jax_parallel_state.py +14 -0
  114. tpu_inference/distributed/tpu_connector.py +14 -9
  115. tpu_inference/distributed/utils.py +56 -4
  116. tpu_inference/executors/__init__.py +13 -0
  117. tpu_inference/executors/ray_distributed_executor.py +20 -3
  118. tpu_inference/experimental/__init__.py +13 -0
  119. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  120. tpu_inference/kernels/__init__.py +13 -0
  121. tpu_inference/kernels/collectives/__init__.py +13 -0
  122. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  123. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  124. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  125. tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
  126. tpu_inference/kernels/megablox/__init__.py +13 -0
  127. tpu_inference/kernels/megablox/common.py +54 -0
  128. tpu_inference/kernels/megablox/gmm.py +646 -0
  129. tpu_inference/kernels/mla/__init__.py +13 -0
  130. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  131. tpu_inference/kernels/mla/v1/kernel.py +20 -26
  132. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  133. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  134. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  135. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  136. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
  137. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
  138. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  139. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
  140. tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
  141. tpu_inference/layers/__init__.py +13 -0
  142. tpu_inference/layers/common/__init__.py +13 -0
  143. tpu_inference/layers/common/attention_interface.py +26 -19
  144. tpu_inference/layers/common/attention_metadata.py +14 -0
  145. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  146. tpu_inference/layers/common/quant_methods.py +15 -0
  147. tpu_inference/layers/common/quantization.py +282 -0
  148. tpu_inference/layers/common/sharding.py +22 -3
  149. tpu_inference/layers/common/utils.py +94 -0
  150. tpu_inference/layers/jax/__init__.py +13 -0
  151. tpu_inference/layers/jax/attention/__init__.py +13 -0
  152. tpu_inference/layers/jax/attention/attention.py +19 -6
  153. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
  154. tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
  155. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  156. tpu_inference/layers/jax/base.py +14 -0
  157. tpu_inference/layers/jax/constants.py +13 -0
  158. tpu_inference/layers/jax/layers.py +14 -0
  159. tpu_inference/layers/jax/misc.py +14 -0
  160. tpu_inference/layers/jax/moe/__init__.py +13 -0
  161. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  162. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  163. tpu_inference/layers/jax/moe/moe.py +43 -3
  164. tpu_inference/layers/jax/pp_utils.py +53 -0
  165. tpu_inference/layers/jax/rope.py +14 -0
  166. tpu_inference/layers/jax/rope_interface.py +14 -0
  167. tpu_inference/layers/jax/sample/__init__.py +13 -0
  168. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  169. tpu_inference/layers/jax/sample/sampling.py +15 -1
  170. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  171. tpu_inference/layers/jax/transformer_block.py +14 -0
  172. tpu_inference/layers/vllm/__init__.py +13 -0
  173. tpu_inference/layers/vllm/attention.py +4 -4
  174. tpu_inference/layers/vllm/fused_moe.py +100 -455
  175. tpu_inference/layers/vllm/linear.py +64 -0
  176. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  177. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  178. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  179. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  180. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  181. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  182. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  183. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
  184. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  188. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
  189. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  190. tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
  191. tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
  192. tpu_inference/lora/__init__.py +13 -0
  193. tpu_inference/lora/torch_lora_ops.py +8 -13
  194. tpu_inference/models/__init__.py +13 -0
  195. tpu_inference/models/common/__init__.py +13 -0
  196. tpu_inference/models/common/model_loader.py +37 -16
  197. tpu_inference/models/jax/__init__.py +13 -0
  198. tpu_inference/models/jax/deepseek_v3.py +113 -124
  199. tpu_inference/models/jax/gpt_oss.py +23 -7
  200. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  201. tpu_inference/models/jax/llama3.py +99 -36
  202. tpu_inference/models/jax/llama4.py +14 -0
  203. tpu_inference/models/jax/llama_eagle3.py +14 -0
  204. tpu_inference/models/jax/llama_guard_4.py +15 -1
  205. tpu_inference/models/jax/qwen2.py +17 -2
  206. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  207. tpu_inference/models/jax/qwen3.py +17 -2
  208. tpu_inference/models/jax/utils/__init__.py +13 -0
  209. tpu_inference/models/jax/utils/file_utils.py +14 -0
  210. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  211. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
  213. tpu_inference/models/jax/utils/weight_utils.py +32 -1
  214. tpu_inference/models/vllm/__init__.py +13 -0
  215. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
  216. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  217. tpu_inference/platforms/__init__.py +14 -0
  218. tpu_inference/platforms/tpu_platform.py +27 -29
  219. tpu_inference/runner/__init__.py +13 -0
  220. tpu_inference/runner/compilation_manager.py +69 -35
  221. tpu_inference/runner/kv_cache.py +14 -0
  222. tpu_inference/runner/kv_cache_manager.py +15 -2
  223. tpu_inference/runner/lora_utils.py +16 -1
  224. tpu_inference/runner/multimodal_manager.py +16 -2
  225. tpu_inference/runner/persistent_batch_manager.py +14 -0
  226. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  227. tpu_inference/runner/structured_decoding_manager.py +14 -0
  228. tpu_inference/runner/tpu_runner.py +30 -10
  229. tpu_inference/spec_decode/__init__.py +13 -0
  230. tpu_inference/spec_decode/jax/__init__.py +13 -0
  231. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  232. tpu_inference/tpu_info.py +14 -0
  233. tpu_inference/utils.py +31 -30
  234. tpu_inference/worker/__init__.py +13 -0
  235. tpu_inference/worker/tpu_worker.py +23 -7
  236. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
  237. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  238. tpu_inference/layers/vllm/linear_common.py +0 -208
  239. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  240. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  241. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  242. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  245. tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
  246. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  247. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  248. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,268 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # tests/e2e/test_model_loader.py
16
+
17
+ import os
18
+ import re
19
+ import signal
20
+ import subprocess
21
+ import sys
22
+ import tempfile
23
+ import time
24
+
25
+ import pytest
26
+ import requests
27
+ import torch
28
+ from flax import nnx
29
+ from vllm.model_executor.models.registry import ModelRegistry
30
+
31
+ from tpu_inference.models.common.model_loader import (_MODEL_REGISTRY,
32
+ register_model)
33
+
34
+
35
+ @pytest.fixture
36
+ def cleanup_registries():
37
+ """Cleans up the model registries before and after each test."""
38
+ _MODEL_REGISTRY.clear()
39
+ # vLLM's ModelRegistry uses a class-level dictionary to store model classes.
40
+ # We need to clear it to ensure test isolation.
41
+ if hasattr(ModelRegistry, "models"):
42
+ ModelRegistry.models.clear()
43
+ yield
44
+ _MODEL_REGISTRY.clear()
45
+ if hasattr(ModelRegistry, "models"):
46
+ ModelRegistry.models.clear()
47
+
48
+
49
+ class DummyGoodModel(nnx.Module):
50
+ """A valid model that conforms to the expected interface."""
51
+
52
+ def __init__(self, vllm_config=None, rng=None, mesh=None):
53
+ pass
54
+
55
+ def __call__(self,
56
+ kv_caches=None,
57
+ input_ids=None,
58
+ attention_metadata=None):
59
+ pass
60
+
61
+
62
+ def test_register_model_success(cleanup_registries):
63
+ """Tests that a valid model is registered successfully."""
64
+ arch = "DummyGoodModelForCausalLM"
65
+ register_model(arch, DummyGoodModel)
66
+
67
+ # Check tpu_inference registry
68
+ assert arch in _MODEL_REGISTRY
69
+
70
+ class MockModelConfig:
71
+
72
+ def __init__(self, architectures):
73
+ self.hf_config = self._MockHfConfig(architectures)
74
+ self.model_impl = "flax_nnx"
75
+
76
+ class _MockHfConfig:
77
+
78
+ def __init__(self, architectures):
79
+ self.architectures = architectures
80
+
81
+ model_config = MockModelConfig(architectures=[arch])
82
+ vllm_compatible_model, _ = ModelRegistry.resolve_model_cls(
83
+ architectures=[arch], model_config=model_config)
84
+ assert vllm_compatible_model is not None
85
+ assert issubclass(vllm_compatible_model, torch.nn.Module)
86
+ assert issubclass(vllm_compatible_model, DummyGoodModel)
87
+
88
+
89
+ try:
90
+ # Attempt to import vLLM's interface validation function
91
+ from vllm.model_executor.models.interfaces_base import is_vllm_model
92
+ VLLM_INTERFACE_CHECK_AVAILABLE = True
93
+ except ImportError:
94
+ VLLM_INTERFACE_CHECK_AVAILABLE = False
95
+
96
+
97
+ @pytest.mark.skipif(not VLLM_INTERFACE_CHECK_AVAILABLE,
98
+ reason="is_vllm_model could not be imported from vllm.")
99
+ def test_registered_model_passes_vllm_interface_check(cleanup_registries):
100
+ """
101
+ Ensures the wrapped model passes vLLM's own interface validation.
102
+
103
+ This test is future-proof. If vLLM adds new requirements to its
104
+ model interface, this test will fail, signaling that the wrapper
105
+ in `register_model` needs to be updated.
106
+ """
107
+ arch = "DummyGoodModelForCausalLM"
108
+ register_model(arch, DummyGoodModel)
109
+
110
+ class MockModelConfig:
111
+
112
+ def __init__(self, architectures):
113
+ self.hf_config = self._MockHfConfig(architectures)
114
+ self.model_impl = "flax_nnx"
115
+
116
+ class _MockHfConfig:
117
+
118
+ def __init__(self, architectures):
119
+ self.architectures = architectures
120
+
121
+ model_config = MockModelConfig(architectures=[arch])
122
+ vllm_compatible_model, _ = ModelRegistry.resolve_model_cls(
123
+ architectures=[arch], model_config=model_config)
124
+
125
+ # This directly uses vLLM's checker, so it's always up-to-date.
126
+ # We assume is_vllm_model returns True for a valid model, and either
127
+ # returns False or raises an exception for an invalid one.
128
+ assert is_vllm_model(vllm_compatible_model)
129
+
130
+
131
+ def _run_server_and_bench(model_name: str, model_impl_type: str,
132
+ port: int) -> float:
133
+ env = os.environ.copy()
134
+ env["MODEL_IMPL_TYPE"] = model_impl_type
135
+
136
+ # Start server
137
+ server_cmd = [
138
+ sys.executable,
139
+ "-m",
140
+ "vllm.entrypoints.cli.main",
141
+ "serve",
142
+ model_name,
143
+ "--port",
144
+ str(port),
145
+ "--max-model-len",
146
+ "2048",
147
+ "--tensor-parallel-size",
148
+ "1",
149
+ "--disable-log-requests",
150
+ "--no-enable-prefix-caching",
151
+ "--gpu-memory-utilization",
152
+ "0.90",
153
+ ]
154
+
155
+ print(f"Starting server ({model_impl_type}) on port {port}...")
156
+ # Use a new process group so we can kill the server and its children
157
+ # Use temporary files for stdout/stderr to avoid pipe buffer deadlocks
158
+ stdout_file = tempfile.TemporaryFile(mode='w+b')
159
+ stderr_file = tempfile.TemporaryFile(mode='w+b')
160
+ server_process = subprocess.Popen(server_cmd,
161
+ env=env,
162
+ stdout=stdout_file,
163
+ stderr=stderr_file,
164
+ preexec_fn=os.setsid)
165
+
166
+ try:
167
+ # Wait for server to be ready
168
+ start_time = time.time()
169
+ server_ready = False
170
+ while time.time() - start_time < 600: # 10 minutes timeout
171
+ try:
172
+ if requests.get(
173
+ f"http://localhost:{port}/health").status_code == 200:
174
+ server_ready = True
175
+ break
176
+ except requests.exceptions.RequestException:
177
+ pass
178
+
179
+ if server_process.poll() is not None:
180
+ stdout_file.seek(0)
181
+ stderr_file.seek(0)
182
+ stdout = stdout_file.read().decode("utf-8", errors="replace")
183
+ stderr = stderr_file.read().decode("utf-8", errors="replace")
184
+ raise RuntimeError(
185
+ f"Server process exited unexpectedly.\nStdout: {stdout}\nStderr: {stderr}"
186
+ )
187
+
188
+ time.sleep(5)
189
+
190
+ if not server_ready:
191
+ stdout_file.seek(0)
192
+ stderr_file.seek(0)
193
+ stdout = stdout_file.read().decode("utf-8", errors="replace")
194
+ stderr = stderr_file.read().decode("utf-8", errors="replace")
195
+ raise RuntimeError(
196
+ f"Server failed to start within timeout.\nStdout: {stdout}\nStderr: {stderr}"
197
+ )
198
+
199
+ print("Server is ready. Running benchmark...")
200
+
201
+ # Run benchmark
202
+ bench_cmd = [
203
+ "vllm", "bench", "serve", "--model", model_name, "--port",
204
+ str(port), "--dataset-name", "random", "--random-input-len", "50",
205
+ "--random-output-len", "128", "--num-prompts", "20"
206
+ ]
207
+
208
+ result = subprocess.run(bench_cmd,
209
+ env=env,
210
+ capture_output=True,
211
+ text=True)
212
+
213
+ if result.returncode != 0:
214
+ raise RuntimeError(
215
+ f"Benchmark failed.\nStdout: {result.stdout}\nStderr: {result.stderr}"
216
+ )
217
+
218
+ # Parse throughput
219
+ # Output example: "Request throughput (req/s): 12.34"
220
+ match = re.search(r"Request throughput \(req/s\):\s+([\d\.]+)",
221
+ result.stdout)
222
+ if not match:
223
+ raise ValueError(
224
+ f"Could not parse throughput from output:\n{result.stdout}")
225
+
226
+ throughput = float(match.group(1))
227
+ return throughput
228
+
229
+ finally:
230
+ print("Stopping server...")
231
+ try:
232
+ os.killpg(os.getpgid(server_process.pid), signal.SIGTERM)
233
+ except ProcessLookupError:
234
+ pass
235
+ server_process.wait()
236
+ stdout_file.close()
237
+ stderr_file.close()
238
+ # Wait for TPU cleanup
239
+ time.sleep(5)
240
+
241
+
242
+ def test_flax_nnx_vs_vllm_performance():
243
+ """
244
+ Compares the performance of flax_nnx and vllm model implementations.
245
+
246
+ This test ensures that the JAX-native (`flax_nnx`) implementation's
247
+ performance is not significantly different from the vLLM-native PyTorch
248
+ (`vllm`) implementation. It measures the request throughput for both
249
+ backends and asserts that the percentage
250
+ difference is within a reasonable threshold.
251
+ """
252
+ model_name = "Qwen/Qwen3-4B"
253
+ # This should be 2-3% but 6% reduces flakiness.
254
+ percentage_difference_threshold = 0.06
255
+
256
+ throughput_vllm = _run_server_and_bench(model_name, "vllm", 8001)
257
+ throughput_flax = _run_server_and_bench(model_name, "flax_nnx", 8002)
258
+
259
+ print(f"vLLM (PyTorch) throughput: {throughput_vllm:.2f} req/s.")
260
+ print(f"flax_nnx (JAX) throughput: {throughput_flax:.2f} req/s.")
261
+
262
+ percentage_diff = abs(throughput_flax - throughput_vllm) / throughput_vllm
263
+ print(f"Percentage difference in throughput: {percentage_diff:.2%}.")
264
+
265
+ assert percentage_diff < percentage_difference_threshold, (
266
+ f"The performance difference between flax_nnx and vllm is too high. "
267
+ f"Difference: {percentage_diff:.2%}, Threshold: {percentage_difference_threshold:.2%}"
268
+ )
@@ -0,0 +1,111 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ #
3
+ # A simplified example to run multi-modal inference and verify the output.
4
+ # This script is a self-contained test that runs a single prompt and
5
+ # compares the output to a known-good output.
6
+
7
+ import difflib
8
+ import os
9
+ from dataclasses import asdict
10
+
11
+ import pytest
12
+ from vllm import LLM, EngineArgs, SamplingParams
13
+ from vllm.assets.image import ImageAsset
14
+ from vllm.multimodal.image import convert_image_mode
15
+
16
+ # Expected partial text output from the model. This is based on a previous
17
+ # run and is used for verification. The test is considered passed if the
18
+ # generated output match with this text.
19
+ EXPECTED_TEXT = (
20
+ "The image depicts a tall, cylindrical tower with a lattice-like structure, surrounded by cherry blossom trees in full bloom. The cherry blossoms are in various stages of opening, with pink petals covering the branches. The sky is clear and blue, providing a vibrant backdrop to the scene. The tower appears to be a significant landmark"
21
+ )
22
+
23
+
24
+ # NOTE: Could be extended to more mm models/configs as needed
25
+ @pytest.mark.parametrize("enable_dynamic_image_sizes", [False, True])
26
+ def test_multi_modal_inference(monkeypatch, enable_dynamic_image_sizes):
27
+ """
28
+ Runs multi-modal inference and verifies the output.
29
+ """
30
+ os.environ['SKIP_JAX_PRECOMPILE'] = '1' # Skip warmup to save time.
31
+ os.environ[
32
+ 'VLLM_XLA_CHECK_RECOMPILATION'] = '0' # Allow compilation during execution.
33
+
34
+ monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
35
+
36
+ # --- Configuration ---
37
+ model = "Qwen/Qwen2.5-VL-3B-Instruct"
38
+ tensor_parallel_size = 1
39
+ temperature = 0.0
40
+ max_tokens = 64
41
+ max_model_len = 4096
42
+ gpu_memory_utilization = 0.5
43
+ modality = "image"
44
+
45
+ print("Preparing for multi-modal inference...")
46
+
47
+ # --- Prepare Inputs ---
48
+ image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
49
+ question = "What is the content of this image?"
50
+
51
+ # Using Qwen2.5-VL prompt template
52
+ # NOTE: other models may be different
53
+ prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
54
+ f"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
55
+ f"{question}<|im_end|>\n"
56
+ "<|im_start|>assistant\n")
57
+
58
+ # --- Setup vLLM Engine ---
59
+ engine_args = EngineArgs(
60
+ model=model,
61
+ max_model_len=max_model_len,
62
+ tensor_parallel_size=tensor_parallel_size,
63
+ gpu_memory_utilization=gpu_memory_utilization,
64
+ max_num_seqs=1,
65
+ mm_processor_kwargs={
66
+ "min_pixels": 28 * 28,
67
+ "max_pixels": 1280 * 28 * 28,
68
+ "fps": 1,
69
+ },
70
+ limit_mm_per_prompt={modality: 1},
71
+ )
72
+ engine_args = asdict(engine_args)
73
+ if engine_args.get("additional_config") is None:
74
+ engine_args["additional_config"] = {}
75
+
76
+ engine_args["additional_config"][
77
+ "enable_dynamic_image_sizes"] = enable_dynamic_image_sizes
78
+ llm = LLM(**engine_args)
79
+
80
+ sampling_params = SamplingParams(
81
+ temperature=temperature,
82
+ max_tokens=max_tokens,
83
+ )
84
+
85
+ inputs = {
86
+ "prompt": prompt,
87
+ "multi_modal_data": {
88
+ "image": image
89
+ },
90
+ }
91
+
92
+ # --- Run Inference ---
93
+ print("Running inference...")
94
+ outputs = llm.generate(inputs, sampling_params)
95
+
96
+ # --- Verification ---
97
+ generated_text = outputs[0].outputs[0].text.strip()
98
+
99
+ print("-" * 50)
100
+ print("Generated Text:")
101
+ print(generated_text)
102
+ print("-" * 50)
103
+
104
+ # Check output
105
+ similarity_score = difflib.SequenceMatcher(None, generated_text,
106
+ EXPECTED_TEXT).ratio()
107
+ print(f"Similarity Score: {similarity_score:.4f}")
108
+ assert similarity_score >= 0.85, (
109
+ f"Text similarity too low ({similarity_score:.2f}).\n"
110
+ f"Expected: {EXPECTED_TEXT}\n"
111
+ f"Actual: {generated_text}")
@@ -0,0 +1,265 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
+ import os
5
+ import time
6
+ from dataclasses import asdict
7
+
8
+ import pytest
9
+ from vllm import LLM, EngineArgs, SamplingParams
10
+
11
+
12
@pytest.fixture
def model_name():
    """Return the Hugging Face model id used by the pipeline-parallel tests.

    Llama 3.1 8B Instruct is chosen because the JAX model implementation
    supports pipeline parallelism for it.
    """
    llama_model_id = "meta-llama/Llama-3.1-8B-Instruct"
    return llama_model_id
16
+
17
+
18
@pytest.fixture
def test_prompts():
    """Return a fixed batch of ten short prompts for the pipeline-parallel tests.

    The first five are open-ended completions; the last five are factual
    questions.
    """
    completions = [
        "Hello, my name is",
        "The capital of France is",
        "The colors of the rainbow are",
        "The future of AI is",
        "The president of the United States is",
    ]
    questions = [
        "How many players are on a standard soccer team?",
        "In Greek mythology, who is the god of the sea?",
        "What is the capital of Australia?",
        "What is the largest planet in our solar system?",
        "Who developed the theory of general relativity?",
    ]
    return completions + questions
33
+
34
+
35
@pytest.fixture
def sampling_params():
    """Return the sampling settings shared by the pipeline-parallel tests.

    Uses temperature 0.0 (greedy), up to 32 new tokens with ``ignore_eos``
    set, and ``logprobs=1`` so that outputs carry log-probabilities for the
    correctness comparison.
    """
    settings = dict(
        temperature=0.0,
        max_tokens=32,
        ignore_eos=True,
        logprobs=1,
    )
    return SamplingParams(**settings)
44
+
45
+
46
def _run_inference_with_config(model_name: str,
                               test_prompts: list,
                               sampling_params: SamplingParams,
                               tensor_parallel_size: int = 1,
                               pipeline_parallel_size: int = 1,
                               additional_config: dict | None = None,
                               kv_cache_dtype: str = "auto",
                               enable_prefix_caching: bool = False) -> list:
    """Build an ``LLM`` with the given engine config and run generation.

    Args:
        model_name: Model id passed to ``EngineArgs(model=...)``.
        test_prompts: Prompts passed to ``llm.generate``.
        sampling_params: Sampling settings passed to ``llm.generate``.
        tensor_parallel_size: Tensor-parallel degree.
        pipeline_parallel_size: Pipeline-parallel degree.
        additional_config: Extra config forwarded to ``EngineArgs``.
            ``None`` (the default) means an empty dict. A copy is always
            forwarded, so neither the caller's dict nor a shared default is
            handed to the engine.
        kv_cache_dtype: KV-cache dtype forwarded to ``EngineArgs``.
        enable_prefix_caching: Forwarded to ``EngineArgs``.

    Returns:
        The list returned by ``llm.generate``.
    """
    # Fresh dict per call. The former ``{}`` default was created once at
    # function definition time and shared by every call that relied on it.
    additional_config = dict(additional_config or {})

    # Build EngineArgs, then expand them into LLM keyword arguments.
    # max_model_len / batching limits are fixed small values for these tests.
    engine_args = EngineArgs(
        model=model_name,
        max_model_len=128,
        tensor_parallel_size=tensor_parallel_size,
        pipeline_parallel_size=pipeline_parallel_size,
        gpu_memory_utilization=0.95,
        max_num_batched_tokens=128,
        max_num_seqs=16,
        enable_prefix_caching=enable_prefix_caching,
        additional_config=additional_config,
        kv_cache_dtype=kv_cache_dtype,
    )
    llm = LLM(**asdict(engine_args))

    try:
        return llm.generate(test_prompts, sampling_params)
    finally:
        # Drop the engine even if generation fails, then wait so the TPUs
        # are released before the next test creates another engine.
        del llm
        time.sleep(5)
80
+
81
+
82
@pytest.mark.skip(reason="PP is not fully enabled.")
def test_pipeline_parallelism_jax_model(
    model_name: str,
    test_prompts: list,
    sampling_params: SamplingParams,
):
    """
    Check that pipeline parallelism (PP=2, TP=1) runs on the JAX model
    implementation and produces non-empty text for every prompt.

    Equivalent to:
    python examples/offline_inference.py --tensor_parallel_size=1 --pipeline_parallel_size=2
    """
    pp_config = dict(tensor_parallel_size=1, pipeline_parallel_size=2)
    outputs = _run_inference_with_config(
        model_name=model_name,
        test_prompts=test_prompts,
        sampling_params=sampling_params,
        **pp_config,
    )

    # One result per prompt, each carrying at least one non-blank completion.
    assert len(outputs) == len(test_prompts)
    for request_output in outputs:
        completions = request_output.outputs
        assert len(completions) > 0
        assert completions[0].text.strip()

    print(
        f"✓ Pipeline Parallelism Jax model test passed with {len(outputs)} outputs"
    )
114
+
115
+
116
@pytest.mark.skip(reason="PP is not fully enabled.")
def test_pipeline_parallelism_vllm_model(
    model_name: str,
    test_prompts: list,
    sampling_params: SamplingParams,
):
    """
    Test that pipeline parallelism works on vLLM model implementations.

    NOTE(review): the docstring previously also claimed coverage of tensor
    parallelism, but tensor_parallel_size is 1 here, so TP is not exercised.

    Equivalent to:
    MODEL_IMPL_TYPE=vllm python examples/offline_inference.py --tensor_parallel_size=1 --pipeline_parallel_size=2
    """
    from unittest import mock

    # Scope the override to this test. patch.dict restores os.environ on
    # exit, so later tests (e.g. the JAX correctness test) do not silently
    # inherit MODEL_IMPL_TYPE=vllm.
    with mock.patch.dict(os.environ, {'MODEL_IMPL_TYPE': 'vllm'}):
        # Test with pipeline parallelism enabled
        outputs = _run_inference_with_config(
            model_name=model_name,
            test_prompts=test_prompts,
            sampling_params=sampling_params,
            tensor_parallel_size=1,
            pipeline_parallel_size=2,
        )

    # Verify we got outputs for all prompts
    assert len(outputs) == len(test_prompts)

    # Verify each output has generated text
    for output in outputs:
        assert len(output.outputs) > 0
        assert len(output.outputs[0].text.strip()) > 0

    print(
        f"✓ Pipeline Parallelism vLLM model test passed with {len(outputs)} outputs"
    )
151
+
152
+
153
def _compare_prompt_logprobs(prompt_idx, baseline_logprobs, pp_logprobs):
    """Compare one prompt's per-token logprobs between baseline and PP runs.

    Each element of the two sequences is a per-position mapping from token to
    an entry with a ``.logprob`` attribute. The first key of each mapping is
    treated as the generated token. Comparison stops after the first position
    whose tokens differ: past that point the runs are conditioned on
    different prefixes, so their logprobs are no longer comparable.

    Args:
        prompt_idx: Prompt index, used only in diagnostic output.
        baseline_logprobs: Per-position logprob mappings from the baseline.
        pp_logprobs: Per-position logprob mappings from the PP run.

    Returns:
        ``(mismatches, max_diff)``: the number of compared positions whose
        absolute logprob difference exceeds 1e-3, and the largest difference
        observed.

    Raises:
        AssertionError: if the two sequences have different lengths.
    """
    assert len(baseline_logprobs) == len(pp_logprobs), \
        f"Logprobs length mismatch: {len(baseline_logprobs)} vs {len(pp_logprobs)}"

    mismatches = 0
    max_diff = 0.0
    for token_idx, (base_lp, pp_lp) in enumerate(
            zip(baseline_logprobs, pp_logprobs)):
        if not (base_lp and pp_lp):
            continue

        base_top_token = next(iter(base_lp))
        pp_top_token = next(iter(pp_lp))
        base_logprob_val = base_lp[base_top_token].logprob
        pp_logprob_val = pp_lp[pp_top_token].logprob

        diff = abs(base_logprob_val - pp_logprob_val)
        max_diff = max(max_diff, diff)

        # Allow small numerical differences (e.g., 1e-3)
        if diff > 1e-3:
            mismatches += 1
            print(f"Logprob mismatch in prompt {prompt_idx}, token {token_idx}:")
            print(
                f"  Baseline token: {base_top_token}, logprob: {base_logprob_val:.6f}"
            )
            print(
                f"  PP token: {pp_top_token}, logprob: {pp_logprob_val:.6f}"
            )
            print(f"  Difference: {diff:.6f}")

        if base_top_token != pp_top_token:
            # The generated sequences diverge here; later positions are not
            # comparable, so stop instead of inflating max_diff with noise.
            break

    return mismatches, max_diff


@pytest.mark.skip(reason="PP is not fully enabled.")
def test_pipeline_parallelism_jax_model_correctness(
    model_name: str,
    test_prompts: list,
    sampling_params: SamplingParams,
):
    """
    Test that pipeline parallelism produces results consistent with a baseline.

    Compares outputs of a single-device run (PP=1) with a pipeline-parallel run
    (PP=2), checking both generated text and per-token log probabilities.
    """
    from unittest import mock

    # Use a smaller subset of prompts for correctness testing
    small_prompts = test_prompts[:10]

    # Scope the JAX/XLA overrides to this test. patch.dict restores
    # os.environ on exit, so the settings do not leak into later tests.
    jax_env = {
        'SKIP_JAX_PRECOMPILE': '1',
        'VLLM_XLA_CHECK_RECOMPILATION': '0',
    }
    with mock.patch.dict(os.environ, jax_env):
        # Run baseline (no PP)
        baseline_outputs = _run_inference_with_config(
            model_name=model_name,
            test_prompts=small_prompts,
            sampling_params=sampling_params,
            tensor_parallel_size=1,
            pipeline_parallel_size=1,
        )

        # Run with pipeline parallelism (PP=2)
        pp_outputs = _run_inference_with_config(
            model_name=model_name,
            test_prompts=small_prompts,
            sampling_params=sampling_params,
            tensor_parallel_size=1,
            pipeline_parallel_size=2,
        )

    # Compare outputs - in theory they should be identical for greedy sampling
    # in reality there may be some differences, but overall the outputs should
    # be very similar.

    # an example:
    # prompt: What is the capital of Australia?
    # both answers should be acceptable.
    # The capital of Australia is Canberra. It is located in the Australian Capital Territory (ACT) and is home to many
    # Canberra is the capital of Australia. It is located in the Australian Capital Territory (ACT) and is home to
    assert len(baseline_outputs) == len(pp_outputs)
    # Guard the match-rate division below.
    assert baseline_outputs, "No outputs were generated"

    text_matches = 0
    text_mismatches = 0
    logprob_mismatches = 0
    max_logprob_diff = 0.0

    for i, (baseline, pp_result) in enumerate(zip(baseline_outputs,
                                                  pp_outputs)):
        baseline_text = baseline.outputs[0].text.strip()
        pp_text = pp_result.outputs[0].text.strip()

        # Check text output
        if baseline_text == pp_text:
            text_matches += 1
        else:
            text_mismatches += 1
            print(f"Text mismatch found in prompt {i}:")
            print(f"  Baseline: {baseline_text}")
            print(f"  Pipeline Parallel: {pp_text}")

        # Check log probabilities (only up to the first divergent token)
        baseline_logprobs = baseline.outputs[0].logprobs
        pp_logprobs = pp_result.outputs[0].logprobs
        if baseline_logprobs is not None and pp_logprobs is not None:
            prompt_mismatches, prompt_max_diff = _compare_prompt_logprobs(
                i, baseline_logprobs, pp_logprobs)
            logprob_mismatches += prompt_mismatches
            max_logprob_diff = max(max_logprob_diff, prompt_max_diff)

    print("✓ Correctness test results:")
    print(f"  Text: {text_matches} matches, {text_mismatches} mismatches")
    print(f"  Max logprob difference: {max_logprob_diff:.6e}")
    print(f"  Significant logprob mismatches (>1e-3): {logprob_mismatches}")

    # Allow for some variance due to potential numerical differences
    # but most outputs should match with greedy sampling
    text_match_rate = text_matches / len(baseline_outputs)
    assert text_match_rate >= 0.9, f"Text match rate {text_match_rate:.2%} is too low"

    # Log probabilities should be very close (allow small numerical errors)
    assert max_logprob_diff < 1, f"Max logprob difference {max_logprob_diff} is too large"