tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (251)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +22 -1
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +31 -9
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +77 -36
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -71
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +158 -63
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +53 -30
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +54 -2
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +105 -57
  232. tpu_inference/runner/utils.py +2 -2
  233. tpu_inference/spec_decode/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/__init__.py +13 -0
  235. tpu_inference/spec_decode/jax/eagle3.py +65 -19
  236. tpu_inference/tpu_info.py +14 -0
  237. tpu_inference/utils.py +72 -44
  238. tpu_inference/worker/__init__.py +13 -0
  239. tpu_inference/worker/tpu_worker.py +65 -52
  240. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  241. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  242. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  244. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  245. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  246. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  247. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  248. tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
  249. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  250. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  251. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tests/layers/vllm/test_unquantized.py
@@ -0,0 +1,651 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import tempfile
+
+ import jax
+ import pytest
+ import torch
+ import torchax
+ from jax._src import test_util as jtu
+ from jax.sharding import NamedSharding, PartitionSpec
+ from torchax.interop import torch_view
+ from torchax.ops.mappings import j2t, t2j
+ from vllm.config import ParallelConfig, set_current_vllm_config
+ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                              init_distributed_environment)
+ from vllm.engine.arg_utils import EngineArgs
+ from vllm.forward_context import set_forward_context
+ from vllm.model_executor.layers.fused_moe import FusedMoE
+ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                                LinearBase,
+                                                MergedColumnParallelLinear,
+                                                QKVParallelLinear,
+                                                RowParallelLinear)
+ from vllm.model_executor.model_loader import get_model as vllm_get_model
+
+ from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
+ from tpu_inference.layers.vllm.quantization.unquantized import (
+     VllmUnquantizedConfig, VllmUnquantizedFusedMoEMethod,
+     VllmUnquantizedLinearMethod)
+
+ from . import utils as test_utils
+
+ P = PartitionSpec
+ MODELS = ["Qwen/Qwen2-1.5B-Instruct"]
+
+
+ @pytest.fixture(autouse=True)
+ def setup_environment():
+     # This is a fake config used for init dist env.
+     # RowParallelLinear needs dist env to be initialized.
+     engine_args = EngineArgs(
+         model=MODELS[0],
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+
+     vllm_config = engine_args.create_engine_config()
+
+     with set_current_vllm_config(vllm_config):
+         temp_file = tempfile.mkstemp()[1]
+         init_distributed_environment(
+             1,
+             0,
+             local_rank=0,
+             distributed_init_method=f"file://{temp_file}",
+             backend="gloo")
+         ensure_model_parallel_initialized(1, 1)
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("mesh", [
+     test_utils.get_spmd_mesh(1),
+     test_utils.get_spmd_mesh(jax.local_device_count())
+ ])
+ def test_quant_override(model, mesh):
+
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.model_config.dtype = torch.bfloat16
+
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     assert isinstance(quant_config, VllmUnquantizedConfig)
+     assert quant_config.vllm_config == vllm_config
+     assert quant_config.mesh == mesh
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("mesh", [
+     test_utils.get_spmd_mesh(1),
+     test_utils.get_spmd_mesh(jax.local_device_count())
+ ])
+ def test_loading_model(model, mesh):
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.model_config.dtype = torch.bfloat16
+     vllm_config.quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     vllm_config.device_config.device = "cpu"
+
+     vllm_model = vllm_get_model(vllm_config=vllm_config)
+     layers = test_utils.find_all_layer_type(vllm_model, LinearBase)
+     for layer in layers:
+         assert isinstance(layer.quant_config, VllmUnquantizedConfig)
+         assert isinstance(layer.quant_method, VllmUnquantizedLinearMethod)
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("bias", [False, True])
+ @pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
+ @pytest.mark.parametrize("enable_sp", [False, True])
+ @pytest.mark.parametrize("enable_attn_dp", [False, True])
+ def test_row_parallel_linear(model, bias, num_devices, enable_sp,
+                              enable_attn_dp):
+     # Skip if enable_attn_dp is True but we don't have enough devices
+     if enable_attn_dp and num_devices < 2:
+         pytest.skip("enable_attn_dp requires at least 2 devices")
+
+     mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+     dtype = torch.bfloat16
+
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+     input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
+     input_tensor = input_tensor.to('cpu')
+
+     with set_current_vllm_config(vllm_config):
+         row_linear = RowParallelLinear(
+             input_size=4096,
+             output_size=8192,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+         )
+
+     weight_data = torch.rand_like(row_linear.weight.data) / 10
+     if bias:
+         bias_data = torch.rand_like(row_linear.bias.data)
+
+     row_linear.weight.data = weight_data
+     if bias:
+         row_linear.bias.data = bias_data
+     row_linear = row_linear.to('cpu')
+     row_linear.quant_method.process_weights_after_loading(row_linear)
+     output = row_linear(input_tensor).to(dtype)
+
+     vllm_config.model_config.dtype = dtype
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         jax_row_linear = RowParallelLinear(
+             input_size=4096,
+             output_size=8192,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+             quant_config=quant_config,
+         )
+
+     jax_row_linear.weight.data = weight_data
+     if bias:
+         jax_row_linear.bias.data = bias_data
+
+     jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
+     jax_input_tensor.apply_jax_(jax.device_put,
+                                 NamedSharding(mesh, P(None, None)))
+     with torchax.default_env():
+         assert isinstance(jax_row_linear.quant_method,
+                           VllmUnquantizedLinearMethod)
+         jax_row_linear.quant_method.process_weights_after_loading(
+             jax_row_linear)
+         jax_output = jax_row_linear(jax_input_tensor)
+     # j2t() doens't support bfloat16, so we cast it into float32 as an intermedate step.
+     jax_output = j2t(jax_output.to(torch.float32)).to(dtype)
+
+     torch.testing.assert_close(output, jax_output)
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("bias", [False, True])
+ @pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
+ @pytest.mark.parametrize("enable_sp", [False, True])
+ @pytest.mark.parametrize("enable_attn_dp", [False, True])
+ def test_column_parallel_linear(model, bias, num_devices, enable_sp,
+                                 enable_attn_dp):
+     # Skip if enable_attn_dp is True but we don't have enough devices
+     if enable_attn_dp and num_devices < 2:
+         pytest.skip("enable_attn_dp requires at least 2 devices")
+
+     mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+     dtype = torch.bfloat16
+
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+     input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
+     input_tensor = input_tensor.to('cpu')
+
+     with set_current_vllm_config(vllm_config):
+         column_linear = ColumnParallelLinear(
+             input_size=4096,
+             output_size=8192,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+         )
+
+     weight_data = torch.rand_like(column_linear.weight.data) / 10
+     if bias:
+         bias_data = torch.rand_like(column_linear.bias.data)
+
+     column_linear.weight.data = weight_data
+     if bias:
+         column_linear.bias.data = bias_data
+     column_linear = column_linear.to('cpu')
+     column_linear.quant_method.process_weights_after_loading(column_linear)
+     output = column_linear(input_tensor).to(dtype)
+
+     vllm_config.model_config.dtype = dtype
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         jax_column_linear = ColumnParallelLinear(
+             input_size=4096,
+             output_size=8192,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+             quant_config=quant_config,
+         )
+
+     jax_column_linear.weight.data = weight_data
+     if bias:
+         jax_column_linear.bias.data = bias_data
+
+     jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
+     jax_input_tensor.apply_jax_(jax.device_put,
+                                 NamedSharding(mesh, P(None, None)))
+     with torchax.default_env():
+         assert isinstance(jax_column_linear.quant_method,
+                           VllmUnquantizedLinearMethod)
+         jax_column_linear.quant_method.process_weights_after_loading(
+             jax_column_linear)
+         jax_output = jax_column_linear(jax_input_tensor)
+     jax_output = j2t(jax_output.to(torch.float32)).to(dtype)
+
+     torch.testing.assert_close(output, jax_output)
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("bias", [False, True])
+ @pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
+ @pytest.mark.parametrize("enable_sp", [False, True])
+ @pytest.mark.parametrize("fuse_matmuls", [False, True])
+ @pytest.mark.parametrize("enable_attn_dp", [False, True])
+ def test_qkv_parallel_linear(model, bias, num_devices, enable_sp, fuse_matmuls,
+                              enable_attn_dp):
+     # Skip if enable_attn_dp is True but we don't have enough devices
+     if enable_attn_dp and num_devices < 2:
+         pytest.skip("enable_attn_dp requires at least 2 devices")
+
+     mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+     dtype = torch.bfloat16
+
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+     input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
+     input_tensor = input_tensor.to('cpu')
+
+     with set_current_vllm_config(vllm_config):
+         qkv_linear = QKVParallelLinear(
+             hidden_size=4096,
+             head_size=128,
+             total_num_heads=32,
+             total_num_kv_heads=8,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+         )
+
+     weight_data = torch.rand_like(qkv_linear.weight.data) / 10
+     if bias:
+         bias_data = torch.rand_like(qkv_linear.bias.data)
+
+     qkv_linear.weight.data = weight_data
+     if bias:
+         qkv_linear.bias.data = bias_data
+     qkv_linear = qkv_linear.to('cpu')
+     qkv_linear.quant_method.process_weights_after_loading(qkv_linear)
+     output = qkv_linear(input_tensor).to(dtype)
+
+     vllm_config.model_config.dtype = dtype
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         jax_qkv_linear = QKVParallelLinear(
+             hidden_size=4096,
+             head_size=128,
+             total_num_heads=32,
+             total_num_kv_heads=8,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+             quant_config=quant_config,
+         )
+     jax_qkv_linear.quant_method.fuse_matmuls = fuse_matmuls
+
+     jax_qkv_linear.weight.data = weight_data
+     if bias:
+         jax_qkv_linear.bias.data = bias_data
+
+     jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
+     jax_input_tensor.apply_jax_(jax.device_put,
+                                 NamedSharding(mesh, P(None, None)))
+     with torchax.default_env():
+         assert isinstance(jax_qkv_linear.quant_method,
+                           VllmUnquantizedLinearMethod)
+         jax_qkv_linear.quant_method.process_weights_after_loading(
+             jax_qkv_linear)
+         jax_output = jax_qkv_linear(jax_input_tensor)
+     jax_output = j2t(jax_output.to(torch.float32)).to(dtype)
+
+     torch.testing.assert_close(output, jax_output)
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("bias", [False, True])
+ @pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
+ @pytest.mark.parametrize("fuse_matmuls", [False, True])
+ @pytest.mark.parametrize("enable_sp", [False, True])
+ @pytest.mark.parametrize("enable_attn_dp", [False, True])
+ def test_merged_column_parallel_linear(model, bias, num_devices, fuse_matmuls,
+                                        enable_sp, enable_attn_dp):
+     # Skip if enable_attn_dp is True but we don't have enough devices
+     if enable_attn_dp and num_devices < 2:
+         pytest.skip("enable_attn_dp requires at least 2 devices")
+
+     mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+     dtype = torch.bfloat16
+
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+     input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
+     input_tensor = input_tensor.to('cpu')
+
+     # Call vLLM code
+     with set_current_vllm_config(vllm_config):
+         merged_column_linear = MergedColumnParallelLinear(
+             input_size=4096,
+             output_sizes=[14336] * 2,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+         )
+
+     weight_data = torch.rand_like(merged_column_linear.weight.data) / 10
+     if bias:
+         bias_data = torch.rand_like(merged_column_linear.bias.data)
+
+     merged_column_linear.weight.data = weight_data
+     if bias:
+         merged_column_linear.bias.data = bias_data
+     merged_column_linear = merged_column_linear.to('cpu')
+     merged_column_linear.quant_method.process_weights_after_loading(
+         merged_column_linear)
+     output = merged_column_linear(input_tensor).to(dtype)
+
+     # Call tpu_inference code
+     vllm_config.model_config.dtype = dtype
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         jax_merged_column_linear = MergedColumnParallelLinear(
+             input_size=4096,
+             output_sizes=[14336] * 2,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+             quant_config=quant_config,
+         )
+     jax_merged_column_linear.quant_method.fuse_matmuls = fuse_matmuls
+
+     jax_merged_column_linear.weight.data = weight_data
+     if bias:
+         jax_merged_column_linear.bias.data = bias_data
+
+     jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
+     jax_input_tensor.apply_jax_(jax.device_put,
+                                 NamedSharding(mesh, P(None, None)))
+     with torchax.default_env():
+         assert isinstance(jax_merged_column_linear.quant_method,
+                           VllmUnquantizedLinearMethod)
+         jax_merged_column_linear.quant_method.process_weights_after_loading(
+             jax_merged_column_linear)
+         jax_output = jax_merged_column_linear(jax_input_tensor)
+     jax_output = j2t(jax_output.to(torch.float32)).to(dtype)
+
+     torch.testing.assert_close(output, jax_output)
+
+
+ @pytest.mark.parametrize("use_ep", [True, False])
+ @pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
+ @pytest.mark.parametrize("num_tokens", [8])
+ @pytest.mark.parametrize("intermediate_size", [1024, 2048])
+ @pytest.mark.parametrize("hidden_size", [128, 512])
+ @pytest.mark.parametrize("num_experts", [8])
+ @pytest.mark.parametrize("topk", [2])
+ @pytest.mark.parametrize("has_bias", [False, True])
+ @pytest.mark.parametrize("activation", ["silu", "swigluoai"])
+ @pytest.mark.parametrize("enable_attn_dp", [False, True])
+ def test_fused_moe(use_ep, num_devices, num_tokens, intermediate_size,
+                    hidden_size, num_experts, topk, has_bias, activation,
+                    enable_attn_dp):
+     # Skip if enable_attn_dp is True but we don't have enough devices
+     if enable_attn_dp and num_devices < 2:
+         pytest.skip("enable_attn_dp requires at least 2 devices")
+
+     mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+
+     torch.manual_seed(42)
+     dtype = torch.bfloat16
+
+     a = torch.randn((num_tokens, hidden_size), dtype=dtype) / 10
+     w1 = torch.randn(
+         (num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 10
+     w2 = torch.randn(
+         (num_experts, hidden_size, intermediate_size), dtype=dtype) / 10
+     score = torch.randn((num_tokens, num_experts), dtype=dtype)
+
+     w1_bias = w2_bias = None
+     if has_bias:
+         w1_bias = torch.randn(
+             (num_experts, 2 * intermediate_size), dtype=dtype) / 10
+         w2_bias = torch.randn((num_experts, hidden_size), dtype=dtype) / 10
+
+     engine_args = EngineArgs(
+         model="Qwen/Qwen2-1.5B-Instruct",
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.model_config.dtype = dtype
+
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         vllm_fused_moe = FusedMoE(
+             num_experts=num_experts,
+             top_k=topk,
+             hidden_size=hidden_size,
+             intermediate_size=intermediate_size,
+             reduce_results=False,
+             renormalize=False,
+             tp_size=1,
+             dp_size=1,
+             quant_config=quant_config,
+             has_bias=has_bias,
+             activation=activation,
+         )
+     vllm_fused_moe.moe_parallel_config.use_ep = use_ep
+     vllm_fused_moe.w13_weight.data = w1
+     vllm_fused_moe.w2_weight.data = w2
+     if has_bias:
+         vllm_fused_moe.w13_bias.data = w1_bias
+         vllm_fused_moe.w2_bias.data = w2_bias
+
+     expected = test_utils.ref_moe(a, score, w1, w2, w1_bias, w2_bias,
+                                   vllm_fused_moe.top_k,
+                                   vllm_fused_moe.renormalize,
+                                   vllm_fused_moe.activation)
+
+     with torchax.default_env(), set_forward_context(None, vllm_config):
+         assert isinstance(vllm_fused_moe.quant_method,
+                           VllmUnquantizedFusedMoEMethod)
+
+         jax_a = a.to('jax')
+         score = score.to('jax')
+
+         vllm_fused_moe.quant_method.process_weights_after_loading(
+             vllm_fused_moe)
+         actual = vllm_fused_moe(jax_a, score)
+
+     torch.testing.assert_close(expected,
+                                actual,
+                                check_device=False,
+                                atol=1e-1,
+                                rtol=1e-1)
+
+
+ @pytest.mark.parametrize("num_devices", [jax.local_device_count()])
+ @pytest.mark.parametrize("num_tokens", [128, 512])
+ @pytest.mark.parametrize("intermediate_size", [512])
+ @pytest.mark.parametrize("hidden_size", [512])
+ @pytest.mark.parametrize("num_experts", [32])
+ @pytest.mark.parametrize("topk", [8])
+ @pytest.mark.parametrize("has_bias", [False, True])
+ @pytest.mark.parametrize("enable_attn_dp", [False, True])
+ def test_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,
+                               hidden_size, num_experts, topk, has_bias,
+                               enable_attn_dp):
+     # Skip if enable_attn_dp is True but we don't have enough devices
+     if enable_attn_dp and num_devices < 2:
+         pytest.skip("enable_attn_dp requires at least 2 devices")
+
+     # Skip attn_dp tests for fused_moe_use_kernel since the kernel only supports 2D mesh
+     if enable_attn_dp:
+         pytest.skip(
+             "fused_moe kernel does not support attn_dp (requires 2D mesh)")
+
+     mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+
+     # TODO(Qiliang Cui): Remove when issue is resolved.
+     if not jtu.is_device_tpu_at_least(version=7):
+         pytest.skip(allow_module_level=True, reason="Expected TPUv7+")
+
+     torch.manual_seed(42)
+     dtype = torch.bfloat16
+
+     a = torch.randn((num_tokens, hidden_size), dtype=dtype) / 10
+     w1 = torch.randn(
+         (num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 10
+     w2 = torch.randn(
+         (num_experts, hidden_size, intermediate_size), dtype=dtype) / 10
+
+     w1_bias = w2_bias = None
+     if has_bias:
+         w1_bias = torch.randn(
+             (num_experts, 2 * intermediate_size), dtype=dtype) / 10
+         w2_bias = torch.randn((num_experts, hidden_size), dtype=dtype) / 10
+
+     # Use deterministic gating_output generation (same logic as fused_moe_v1_test.py)
+     # Generate base gating scores with deterministic pattern
+     score = (
+         torch.randn((num_tokens, num_experts), dtype=torch.float32) +
+         torch.arange(num_tokens * num_experts, dtype=torch.float32).reshape(
+             num_tokens, num_experts) / 100)
+
+     # Generate unique top-k indices
+     generator = torch.Generator()
+     generator.manual_seed(42)
+     top_k_indices = torch.randint(0,
+                                   num_experts - 1, (num_tokens, topk),
+                                   dtype=torch.int32,
+                                   generator=generator)
+
+     # Add one-hot encoding weighted by 10 to ensure selected experts have highest scores
+     one_hot = torch.nn.functional.one_hot(top_k_indices.long(),
+                                           num_classes=num_experts).float()
+     one_hot = one_hot.sum(dim=1) * 10
+     score = (score + one_hot).to(dtype)
+
+     engine_args = EngineArgs(
+         model="Qwen/Qwen2-1.5B-Instruct",
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.model_config.dtype = dtype
+     vllm_config.parallel_config = ParallelConfig(
+         tensor_parallel_size=mesh.devices.size)
+
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         vllm_fused_moe = FusedMoE(
+             num_experts=num_experts,
+             top_k=topk,
+             hidden_size=hidden_size,
+             intermediate_size=intermediate_size,
+             reduce_results=True,
+             renormalize=False,
+             tp_size=mesh.devices.size,
+             dp_size=1,
+             quant_config=quant_config,
+             has_bias=has_bias,
+         )
+     vllm_fused_moe.moe_parallel_config.use_ep = True
+     vllm_fused_moe.quant_method.use_kernel = True
+
+     vllm_fused_moe.w13_weight.data = w1
+     vllm_fused_moe.w2_weight.data = w2
+     if has_bias:
+         vllm_fused_moe.w13_bias.data = w1_bias
+         vllm_fused_moe.w2_bias.data = w2_bias
+
+     expected = test_utils.ref_moe(a, score, w1, w2, w1_bias, w2_bias,
+                                   vllm_fused_moe.top_k,
+                                   vllm_fused_moe.renormalize,
+                                   vllm_fused_moe.activation)
+
+     with torchax.default_env(), set_forward_context(None, vllm_config):
+         assert isinstance(vllm_fused_moe.quant_method,
+                           VllmUnquantizedFusedMoEMethod)
+         jax_a = a.to('jax')
+         score = score.to('jax')
+
+         vllm_fused_moe.quant_method.process_weights_after_loading(
+             vllm_fused_moe)
+         vllm_fused_moe.quant_method.block_size = {
+             "bt": 32,
+             "bf": 512,
+             "bd1": 512,
+             "bd2": 512,
+             "btc": 32,
+             "bfc": 256,
+             "bd1c": 256,
+             "bd2c": 256,
+         }
+         actual = vllm_fused_moe(jax_a, score)
+
+     torch.testing.assert_close(
+         expected,
+         actual,
+         check_device=False,
+         atol=1e-2,
+         rtol=1e-2,
+     )