tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (248)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +14 -0
  31. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  32. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
  35. tests/layers/__init__.py +13 -0
  36. tests/layers/common/__init__.py +13 -0
  37. tests/layers/common/test_attention_interface.py +156 -0
  38. tests/layers/common/test_quantization.py +149 -0
  39. tests/layers/jax/__init__.py +13 -0
  40. tests/layers/jax/attention/__init__.py +13 -0
  41. tests/layers/jax/attention/test_common_attention.py +103 -0
  42. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  43. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  44. tests/layers/jax/moe/__init__.py +13 -0
  45. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  46. tests/layers/jax/sample/__init__.py +13 -0
  47. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  48. tests/layers/jax/sample/test_sampling.py +115 -0
  49. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  50. tests/layers/jax/test_layers.py +155 -0
  51. tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
  52. tests/layers/jax/test_rope.py +93 -0
  53. tests/layers/jax/test_sharding.py +159 -0
  54. tests/layers/jax/test_transformer_block.py +152 -0
  55. tests/layers/vllm/__init__.py +13 -0
  56. tests/layers/vllm/test_attention.py +363 -0
  57. tests/layers/vllm/test_awq.py +406 -0
  58. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  59. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  61. tests/layers/vllm/test_fp8.py +17 -0
  62. tests/layers/vllm/test_mxfp4.py +320 -0
  63. tests/layers/vllm/test_unquantized.py +662 -0
  64. tests/layers/vllm/utils.py +87 -0
  65. tests/lora/__init__.py +13 -0
  66. tests/lora/conftest.py +14 -0
  67. tests/lora/test_bgmv.py +14 -0
  68. tests/lora/test_layers.py +25 -8
  69. tests/lora/test_lora.py +15 -1
  70. tests/lora/test_lora_perf.py +14 -0
  71. tests/models/__init__.py +13 -0
  72. tests/models/common/__init__.py +13 -0
  73. tests/models/common/test_model_loader.py +455 -0
  74. tests/models/jax/__init__.py +13 -0
  75. tests/models/jax/test_deepseek_v3.py +401 -0
  76. tests/models/jax/test_llama3.py +184 -0
  77. tests/models/jax/test_llama4.py +298 -0
  78. tests/models/jax/test_llama_eagle3.py +197 -0
  79. tests/models/jax/test_llama_guard_4.py +242 -0
  80. tests/models/jax/test_qwen2.py +172 -0
  81. tests/models/jax/test_qwen2_5_vl.py +605 -0
  82. tests/models/jax/test_qwen3.py +169 -0
  83. tests/models/jax/test_weight_loading.py +180 -0
  84. tests/models/jax/utils/__init__.py +13 -0
  85. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  86. tests/platforms/__init__.py +13 -0
  87. tests/platforms/test_tpu_platform.py +54 -0
  88. tests/runner/__init__.py +13 -0
  89. tests/runner/test_block_table.py +395 -0
  90. tests/runner/test_input_batch.py +226 -0
  91. tests/runner/test_kv_cache.py +220 -0
  92. tests/runner/test_kv_cache_manager.py +498 -0
  93. tests/runner/test_multimodal_manager.py +429 -0
  94. tests/runner/test_persistent_batch_manager.py +84 -0
  95. tests/runner/test_speculative_decoding_manager.py +368 -0
  96. tests/runner/test_structured_decoding_manager.py +220 -0
  97. tests/runner/test_tpu_runner.py +261 -0
  98. tests/runner/test_tpu_runner_dp.py +1099 -0
  99. tests/runner/test_tpu_runner_mesh.py +200 -0
  100. tests/runner/test_utils.py +411 -0
  101. tests/spec_decode/__init__.py +13 -0
  102. tests/spec_decode/test_eagle3.py +311 -0
  103. tests/test_base.py +14 -0
  104. tests/test_tpu_info.py +14 -0
  105. tests/test_utils.py +1 -43
  106. tests/worker/__init__.py +13 -0
  107. tests/worker/tpu_worker_test.py +414 -0
  108. tpu_inference/__init__.py +14 -0
  109. tpu_inference/core/__init__.py +13 -0
  110. tpu_inference/core/sched/__init__.py +13 -0
  111. tpu_inference/core/sched/dp_scheduler.py +372 -56
  112. tpu_inference/distributed/__init__.py +13 -0
  113. tpu_inference/distributed/jax_parallel_state.py +14 -0
  114. tpu_inference/distributed/tpu_connector.py +14 -9
  115. tpu_inference/distributed/utils.py +56 -4
  116. tpu_inference/executors/__init__.py +13 -0
  117. tpu_inference/executors/ray_distributed_executor.py +20 -3
  118. tpu_inference/experimental/__init__.py +13 -0
  119. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  120. tpu_inference/kernels/__init__.py +13 -0
  121. tpu_inference/kernels/collectives/__init__.py +13 -0
  122. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  123. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  124. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  125. tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
  126. tpu_inference/kernels/megablox/__init__.py +13 -0
  127. tpu_inference/kernels/megablox/common.py +54 -0
  128. tpu_inference/kernels/megablox/gmm.py +646 -0
  129. tpu_inference/kernels/mla/__init__.py +13 -0
  130. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  131. tpu_inference/kernels/mla/v1/kernel.py +20 -26
  132. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  133. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  134. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  135. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  136. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
  137. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
  138. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  139. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
  140. tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
  141. tpu_inference/layers/__init__.py +13 -0
  142. tpu_inference/layers/common/__init__.py +13 -0
  143. tpu_inference/layers/common/attention_interface.py +26 -19
  144. tpu_inference/layers/common/attention_metadata.py +14 -0
  145. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  146. tpu_inference/layers/common/quant_methods.py +15 -0
  147. tpu_inference/layers/common/quantization.py +282 -0
  148. tpu_inference/layers/common/sharding.py +22 -3
  149. tpu_inference/layers/common/utils.py +94 -0
  150. tpu_inference/layers/jax/__init__.py +13 -0
  151. tpu_inference/layers/jax/attention/__init__.py +13 -0
  152. tpu_inference/layers/jax/attention/attention.py +19 -6
  153. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
  154. tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
  155. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  156. tpu_inference/layers/jax/base.py +14 -0
  157. tpu_inference/layers/jax/constants.py +13 -0
  158. tpu_inference/layers/jax/layers.py +14 -0
  159. tpu_inference/layers/jax/misc.py +14 -0
  160. tpu_inference/layers/jax/moe/__init__.py +13 -0
  161. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  162. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  163. tpu_inference/layers/jax/moe/moe.py +43 -3
  164. tpu_inference/layers/jax/pp_utils.py +53 -0
  165. tpu_inference/layers/jax/rope.py +14 -0
  166. tpu_inference/layers/jax/rope_interface.py +14 -0
  167. tpu_inference/layers/jax/sample/__init__.py +13 -0
  168. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  169. tpu_inference/layers/jax/sample/sampling.py +15 -1
  170. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  171. tpu_inference/layers/jax/transformer_block.py +14 -0
  172. tpu_inference/layers/vllm/__init__.py +13 -0
  173. tpu_inference/layers/vllm/attention.py +4 -4
  174. tpu_inference/layers/vllm/fused_moe.py +100 -455
  175. tpu_inference/layers/vllm/linear.py +64 -0
  176. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  177. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  178. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  179. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  180. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  181. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  182. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  183. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
  184. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  188. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
  189. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  190. tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
  191. tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
  192. tpu_inference/lora/__init__.py +13 -0
  193. tpu_inference/lora/torch_lora_ops.py +8 -13
  194. tpu_inference/models/__init__.py +13 -0
  195. tpu_inference/models/common/__init__.py +13 -0
  196. tpu_inference/models/common/model_loader.py +37 -16
  197. tpu_inference/models/jax/__init__.py +13 -0
  198. tpu_inference/models/jax/deepseek_v3.py +113 -124
  199. tpu_inference/models/jax/gpt_oss.py +23 -7
  200. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  201. tpu_inference/models/jax/llama3.py +99 -36
  202. tpu_inference/models/jax/llama4.py +14 -0
  203. tpu_inference/models/jax/llama_eagle3.py +14 -0
  204. tpu_inference/models/jax/llama_guard_4.py +15 -1
  205. tpu_inference/models/jax/qwen2.py +17 -2
  206. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  207. tpu_inference/models/jax/qwen3.py +17 -2
  208. tpu_inference/models/jax/utils/__init__.py +13 -0
  209. tpu_inference/models/jax/utils/file_utils.py +14 -0
  210. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  211. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
  213. tpu_inference/models/jax/utils/weight_utils.py +32 -1
  214. tpu_inference/models/vllm/__init__.py +13 -0
  215. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
  216. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  217. tpu_inference/platforms/__init__.py +14 -0
  218. tpu_inference/platforms/tpu_platform.py +27 -29
  219. tpu_inference/runner/__init__.py +13 -0
  220. tpu_inference/runner/compilation_manager.py +69 -35
  221. tpu_inference/runner/kv_cache.py +14 -0
  222. tpu_inference/runner/kv_cache_manager.py +15 -2
  223. tpu_inference/runner/lora_utils.py +16 -1
  224. tpu_inference/runner/multimodal_manager.py +16 -2
  225. tpu_inference/runner/persistent_batch_manager.py +14 -0
  226. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  227. tpu_inference/runner/structured_decoding_manager.py +14 -0
  228. tpu_inference/runner/tpu_runner.py +30 -10
  229. tpu_inference/spec_decode/__init__.py +13 -0
  230. tpu_inference/spec_decode/jax/__init__.py +13 -0
  231. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  232. tpu_inference/tpu_info.py +14 -0
  233. tpu_inference/utils.py +31 -30
  234. tpu_inference/worker/__init__.py +13 -0
  235. tpu_inference/worker/tpu_worker.py +23 -7
  236. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
  237. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  238. tpu_inference/layers/vllm/linear_common.py +0 -208
  239. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  240. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  241. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  242. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  245. tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
  246. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  247. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  248. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0

tests/layers/vllm/test_unquantized.py

@@ -0,0 +1,662 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import tempfile
+ from unittest import mock
+
+ import jax
+ import pytest
+ import torch
+ import torchax
+ from jax._src import test_util as jtu
+ from jax.sharding import NamedSharding, PartitionSpec
+ from torchax.interop import torch_view
+ from torchax.ops.mappings import j2t, t2j
+ from vllm.config import ParallelConfig, set_current_vllm_config
+ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                              init_distributed_environment)
+ from vllm.engine.arg_utils import EngineArgs
+ from vllm.forward_context import set_forward_context
+ from vllm.model_executor.layers.fused_moe import FusedMoE
+ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                                LinearBase,
+                                                MergedColumnParallelLinear,
+                                                QKVParallelLinear,
+                                                RowParallelLinear)
+ from vllm.model_executor.model_loader import get_model as vllm_get_model
+
+ from tpu_inference.layers.vllm.fused_moe import FusedMoEBackend
+ from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
+ from tpu_inference.layers.vllm.quantization.unquantized import (
+     VllmUnquantizedConfig, VllmUnquantizedFusedMoEMethod,
+     VllmUnquantizedLinearMethod)
+
+ from . import utils as test_utils
+
+ P = PartitionSpec
+ MODELS = ["Qwen/Qwen2-1.5B-Instruct"]
+
+
+ @pytest.fixture(autouse=True)
+ def setup_environment():
+     # This is a fake config used for init dist env.
+     # RowParallelLinear needs dist env to be initialized.
+     engine_args = EngineArgs(
+         model=MODELS[0],
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+
+     vllm_config = engine_args.create_engine_config()
+
+     with set_current_vllm_config(vllm_config):
+         temp_file = tempfile.mkstemp()[1]
+         init_distributed_environment(
+             1,
+             0,
+             local_rank=0,
+             distributed_init_method=f"file://{temp_file}",
+             backend="gloo")
+         ensure_model_parallel_initialized(1, 1)
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("mesh", [
+     test_utils.get_spmd_mesh(1),
+     test_utils.get_spmd_mesh(jax.local_device_count())
+ ])
+ def test_quant_override(model, mesh):
+
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.model_config.dtype = torch.bfloat16
+
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     assert isinstance(quant_config, VllmUnquantizedConfig)
+     assert quant_config.vllm_config == vllm_config
+     assert quant_config.mesh == mesh
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("mesh", [
+     test_utils.get_spmd_mesh(1),
+     test_utils.get_spmd_mesh(jax.local_device_count())
+ ])
+ def test_loading_model(model, mesh):
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.model_config.dtype = torch.bfloat16
+     vllm_config.quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     vllm_config.device_config.device = "cpu"
+
+     vllm_model = vllm_get_model(vllm_config=vllm_config)
+     layers = test_utils.find_all_layer_type(vllm_model, LinearBase)
+     for layer in layers:
+         assert isinstance(layer.quant_config, VllmUnquantizedConfig)
+         assert isinstance(layer.quant_method, VllmUnquantizedLinearMethod)
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("bias", [False, True])
+ @pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
+ @pytest.mark.parametrize("enable_sp", [False, True])
+ @pytest.mark.parametrize("enable_attn_dp", [False, True])
+ def test_row_parallel_linear(model, bias, num_devices, enable_sp,
+                              enable_attn_dp):
+     # Skip if enable_attn_dp is True but we don't have enough devices
+     if enable_attn_dp and num_devices < 2:
+         pytest.skip("enable_attn_dp requires at least 2 devices")
+
+     mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+     dtype = torch.bfloat16
+
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+     with set_current_vllm_config(vllm_config):
+         row_linear = RowParallelLinear(
+             input_size=4096,
+             output_size=8192,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+         )
+
+     input_tensor = torch.rand(10, row_linear.input_size, dtype=dtype) / 10
+     input_tensor = input_tensor.to('cpu')
+
+     weight_data = torch.rand_like(row_linear.weight.data) / 10
+     if bias:
+         bias_data = torch.rand_like(row_linear.bias.data)
+
+     row_linear.weight.data = weight_data
+     if bias:
+         row_linear.bias.data = bias_data
+     row_linear = row_linear.to('cpu')
+     row_linear.quant_method.process_weights_after_loading(row_linear)
+     output = row_linear(input_tensor).to(dtype)
+
+     vllm_config.model_config.dtype = dtype
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         jax_row_linear = RowParallelLinear(
+             input_size=4096,
+             output_size=8192,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+             quant_config=quant_config,
+         )
+
+     jax_row_linear.weight.data = weight_data
+     if bias:
+         jax_row_linear.bias.data = bias_data
+
+     jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
+     jax_input_tensor.apply_jax_(jax.device_put,
+                                 NamedSharding(mesh, P(None, None)))
+     with torchax.default_env():
+         assert isinstance(jax_row_linear.quant_method,
+                           VllmUnquantizedLinearMethod)
+         jax_row_linear.quant_method.process_weights_after_loading(
+             jax_row_linear)
+         jax_output = jax_row_linear(jax_input_tensor)
+         # j2t() doesn't support bfloat16, so we cast it to float32 as an intermediate step.
+         jax_output = j2t(jax_output.to(torch.float32)).to(dtype)
+
+     torch.testing.assert_close(output, jax_output)
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("bias", [False, True])
+ @pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
+ @pytest.mark.parametrize("enable_sp", [False, True])
+ @pytest.mark.parametrize("enable_attn_dp", [False, True])
+ def test_column_parallel_linear(model, bias, num_devices, enable_sp,
+                                 enable_attn_dp):
+     # Skip if enable_attn_dp is True but we don't have enough devices
+     if enable_attn_dp and num_devices < 2:
+         pytest.skip("enable_attn_dp requires at least 2 devices")
+
+     mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+     dtype = torch.bfloat16
+
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+     with set_current_vllm_config(vllm_config):
+         column_linear = ColumnParallelLinear(
+             input_size=4096,
+             output_size=8192,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+         )
+
+     input_tensor = torch.rand(10, column_linear.input_size, dtype=dtype) / 10
+     input_tensor = input_tensor.to('cpu')
+
+     weight_data = torch.rand_like(column_linear.weight.data) / 10
+     if bias:
+         bias_data = torch.rand_like(column_linear.bias.data)
+
+     column_linear.weight.data = weight_data
+     if bias:
+         column_linear.bias.data = bias_data
+     column_linear = column_linear.to('cpu')
+     column_linear.quant_method.process_weights_after_loading(column_linear)
+     output = column_linear(input_tensor).to(dtype)
+
+     vllm_config.model_config.dtype = dtype
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         jax_column_linear = ColumnParallelLinear(
+             input_size=4096,
+             output_size=8192,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+             quant_config=quant_config,
+         )
+
+     jax_column_linear.weight.data = weight_data
+     if bias:
+         jax_column_linear.bias.data = bias_data
+
+     jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
+     jax_input_tensor.apply_jax_(jax.device_put,
+                                 NamedSharding(mesh, P(None, None)))
+     with torchax.default_env():
+         assert isinstance(jax_column_linear.quant_method,
+                           VllmUnquantizedLinearMethod)
+         jax_column_linear.quant_method.process_weights_after_loading(
+             jax_column_linear)
+         jax_output = jax_column_linear(jax_input_tensor)
+         jax_output = j2t(jax_output.to(torch.float32)).to(dtype)
+
+     torch.testing.assert_close(output, jax_output)
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("bias", [False, True])
+ @pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
+ @pytest.mark.parametrize("enable_sp", [False, True])
+ @pytest.mark.parametrize("fuse_matmuls", [False, True])
+ @pytest.mark.parametrize("enable_attn_dp", [False, True])
+ def test_qkv_parallel_linear(model, bias, num_devices, enable_sp, fuse_matmuls,
+                              enable_attn_dp):
+     # Skip if enable_attn_dp is True but we don't have enough devices
+     if enable_attn_dp and num_devices < 2:
+         pytest.skip("enable_attn_dp requires at least 2 devices")
+
+     mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+     dtype = torch.bfloat16
+
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+     with set_current_vllm_config(vllm_config):
+         qkv_linear = QKVParallelLinear(
+             hidden_size=4096,
+             head_size=128,
+             total_num_heads=32,
+             total_num_kv_heads=8,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+         )
+
+     input_tensor = torch.rand(10, qkv_linear.input_size, dtype=dtype) / 10
+     input_tensor = input_tensor.to('cpu')
+
+     weight_data = torch.rand_like(qkv_linear.weight.data) / 10
+     if bias:
+         bias_data = torch.rand_like(qkv_linear.bias.data)
+
+     qkv_linear.weight.data = weight_data
+     if bias:
+         qkv_linear.bias.data = bias_data
+     qkv_linear = qkv_linear.to('cpu')
+     qkv_linear.quant_method.process_weights_after_loading(qkv_linear)
+     output = qkv_linear(input_tensor).to(dtype)
+
+     vllm_config.model_config.dtype = dtype
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         jax_qkv_linear = QKVParallelLinear(
+             hidden_size=4096,
+             head_size=128,
+             total_num_heads=32,
+             total_num_kv_heads=8,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+             quant_config=quant_config,
+         )
+     jax_qkv_linear.quant_method.fuse_matmuls = fuse_matmuls
+
+     jax_qkv_linear.weight.data = weight_data
+     if bias:
+         jax_qkv_linear.bias.data = bias_data
+
+     jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
+     jax_input_tensor.apply_jax_(jax.device_put,
+                                 NamedSharding(mesh, P(None, None)))
+     with torchax.default_env():
+         assert isinstance(jax_qkv_linear.quant_method,
+                           VllmUnquantizedLinearMethod)
+         jax_qkv_linear.quant_method.process_weights_after_loading(
+             jax_qkv_linear)
+         jax_output = jax_qkv_linear(jax_input_tensor)
+         jax_output = j2t(jax_output.to(torch.float32)).to(dtype)
+
+     torch.testing.assert_close(output, jax_output)
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("bias", [False, True])
+ @pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
+ @pytest.mark.parametrize("fuse_matmuls", [False, True])
+ @pytest.mark.parametrize("enable_sp", [False, True])
+ @pytest.mark.parametrize("enable_attn_dp", [False, True])
+ def test_merged_column_parallel_linear(model, bias, num_devices, fuse_matmuls,
+                                        enable_sp, enable_attn_dp):
+     # Skip if enable_attn_dp is True but we don't have enough devices
+     if enable_attn_dp and num_devices < 2:
+         pytest.skip("enable_attn_dp requires at least 2 devices")
+
+     mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+     dtype = torch.bfloat16
+
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+     # Call vLLM code
+     with set_current_vllm_config(vllm_config):
+         merged_column_linear = MergedColumnParallelLinear(
+             input_size=4096,
+             output_sizes=[14336] * 2,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+         )
+
+     input_tensor = torch.rand(10, merged_column_linear.input_size,
+                               dtype=dtype) / 10
+     input_tensor = input_tensor.to('cpu')
+
+     weight_data = torch.rand_like(merged_column_linear.weight.data) / 10
+     if bias:
+         bias_data = torch.rand_like(merged_column_linear.bias.data)
+
+     merged_column_linear.weight.data = weight_data
+     if bias:
+         merged_column_linear.bias.data = bias_data
+     merged_column_linear = merged_column_linear.to('cpu')
+     merged_column_linear.quant_method.process_weights_after_loading(
+         merged_column_linear)
+     output = merged_column_linear(input_tensor).to(dtype)
+
+     # Call tpu_inference code
+     vllm_config.model_config.dtype = dtype
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         jax_merged_column_linear = MergedColumnParallelLinear(
+             input_size=4096,
+             output_sizes=[14336] * 2,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+             quant_config=quant_config,
+         )
+     jax_merged_column_linear.quant_method.fuse_matmuls = fuse_matmuls
+
+     jax_merged_column_linear.weight.data = weight_data
+     if bias:
+         jax_merged_column_linear.bias.data = bias_data
+
+     jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
+     jax_input_tensor.apply_jax_(jax.device_put,
+                                 NamedSharding(mesh, P(None, None)))
+     with torchax.default_env():
+         assert isinstance(jax_merged_column_linear.quant_method,
+                           VllmUnquantizedLinearMethod)
+         jax_merged_column_linear.quant_method.process_weights_after_loading(
+             jax_merged_column_linear)
+         jax_output = jax_merged_column_linear(jax_input_tensor)
+         jax_output = j2t(jax_output.to(torch.float32)).to(dtype)
+
+     torch.testing.assert_close(output, jax_output)
+
+
+ @pytest.mark.parametrize("use_ep", [True, False])
+ @pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
+ @pytest.mark.parametrize("num_tokens", [8])
+ @pytest.mark.parametrize("intermediate_size", [1024, 2048])
+ @pytest.mark.parametrize("hidden_size", [128, 512])
+ @pytest.mark.parametrize("num_experts", [8])
+ @pytest.mark.parametrize("topk", [2])
+ @pytest.mark.parametrize("has_bias", [False, True])
+ @pytest.mark.parametrize("activation", ["silu", "swigluoai"])
+ @pytest.mark.parametrize("enable_attn_dp", [False, True])
+ def test_fused_moe(use_ep, num_devices, num_tokens, intermediate_size,
+                    hidden_size, num_experts, topk, has_bias, activation,
+                    enable_attn_dp):
+     # Skip if enable_attn_dp is True but we don't have enough devices
+     if enable_attn_dp and num_devices < 2:
+         pytest.skip("enable_attn_dp requires at least 2 devices")
+
+     mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+
+     torch.manual_seed(42)
+     dtype = torch.bfloat16
+
+     a = torch.randn((num_tokens, hidden_size), dtype=dtype) / 10
+     w1 = torch.randn(
+         (num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 10
+     w2 = torch.randn(
+         (num_experts, hidden_size, intermediate_size), dtype=dtype) / 10
+     score = torch.randn((num_tokens, num_experts), dtype=dtype)
+
+     w1_bias = w2_bias = None
+     if has_bias:
+         w1_bias = torch.randn(
+             (num_experts, 2 * intermediate_size), dtype=dtype) / 10
+         w2_bias = torch.randn((num_experts, hidden_size), dtype=dtype) / 10
+
+     engine_args = EngineArgs(
+         model="Qwen/Qwen2-1.5B-Instruct",
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.model_config.dtype = dtype
+     vllm_config.parallel_config = ParallelConfig(
+         tensor_parallel_size=mesh.devices.size, enable_expert_parallel=use_ep)
+
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         vllm_fused_moe = FusedMoE(
+             num_experts=num_experts,
+             top_k=topk,
+             hidden_size=hidden_size,
+             intermediate_size=intermediate_size,
+             reduce_results=False,
+             renormalize=False,
+             tp_size=1,
+             dp_size=1,
+             quant_config=quant_config,
+             has_bias=has_bias,
+             activation=activation,
+         )
+     vllm_fused_moe.moe_parallel_config.use_ep = use_ep
+     vllm_fused_moe.w13_weight.data = w1
+     vllm_fused_moe.w2_weight.data = w2
+     if has_bias:
+         vllm_fused_moe.w13_bias.data = w1_bias
+         vllm_fused_moe.w2_bias.data = w2_bias
+
+     expected = test_utils.ref_moe(a, score, w1, w2, w1_bias, w2_bias,
+                                   vllm_fused_moe.top_k,
+                                   vllm_fused_moe.renormalize,
+                                   vllm_fused_moe.activation)
+
+     with torchax.default_env(), set_forward_context(None, vllm_config):
+         assert isinstance(vllm_fused_moe.quant_method,
+                           VllmUnquantizedFusedMoEMethod)
+         if use_ep:
+             assert vllm_fused_moe.quant_method.moe_backend == FusedMoEBackend.GMM_EP
+         else:
+             assert vllm_fused_moe.quant_method.moe_backend == FusedMoEBackend.GMM_TP
+
+         jax_a = a.to('jax')
+         score = score.to('jax')
+
+         vllm_fused_moe.quant_method.process_weights_after_loading(
+             vllm_fused_moe)
+         actual = vllm_fused_moe(jax_a, score)
+
+     torch.testing.assert_close(expected,
+                                actual,
+                                check_device=False,
+                                atol=1e-1,
+                                rtol=1e-1)
+
+
+ @pytest.mark.parametrize("num_devices", [jax.local_device_count()])
+ @pytest.mark.parametrize("num_tokens", [128, 512])
+ @pytest.mark.parametrize("intermediate_size", [512])
+ @pytest.mark.parametrize("hidden_size", [512])
+ @pytest.mark.parametrize("num_experts", [32])
+ @pytest.mark.parametrize("topk", [8])
+ @pytest.mark.parametrize("has_bias", [False, True])
+ @pytest.mark.parametrize("enable_attn_dp", [False, True])
+ @mock.patch("os.environ", {"USE_MOE_EP_KERNEL": "1"})
+ def test_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,
+                               hidden_size, num_experts, topk, has_bias,
+                               enable_attn_dp):
+     # Skip if enable_attn_dp is True but we don't have enough devices
+     if enable_attn_dp and num_devices < 2:
+         pytest.skip("enable_attn_dp requires at least 2 devices")
+
+     # Skip attn_dp tests for fused_moe_use_kernel since the kernel only supports 2D mesh
+     if enable_attn_dp:
+         pytest.skip(
+             "fused_moe kernel does not support attn_dp (requires 2D mesh)")
+
+     mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+
+     # TODO(Qiliang Cui): Remove when issue is resolved.
+     if not jtu.is_device_tpu_at_least(version=7):
+         pytest.skip(allow_module_level=True, reason="Expected TPUv7+")
+
+     torch.manual_seed(42)
+     dtype = torch.bfloat16
+
+     a = torch.randn((num_tokens, hidden_size), dtype=dtype) / 10
+     w1 = torch.randn(
+         (num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 10
+     w2 = torch.randn(
+         (num_experts, hidden_size, intermediate_size), dtype=dtype) / 10
+
+     w1_bias = w2_bias = None
+     if has_bias:
+         w1_bias = torch.randn(
+             (num_experts, 2 * intermediate_size), dtype=dtype) / 10
+         w2_bias = torch.randn((num_experts, hidden_size), dtype=dtype) / 10
+
+     # Use deterministic gating_output generation (same logic as fused_moe_v1_test.py)
+     # Generate base gating scores with deterministic pattern
+     score = (
+         torch.randn((num_tokens, num_experts), dtype=torch.float32) +
+         torch.arange(num_tokens * num_experts, dtype=torch.float32).reshape(
+             num_tokens, num_experts) / 100)
+
+     # Generate unique top-k indices
+     generator = torch.Generator()
+     generator.manual_seed(42)
+     top_k_indices = torch.randint(0,
+                                   num_experts - 1, (num_tokens, topk),
+                                   dtype=torch.int32,
+                                   generator=generator)
+
+     # Add one-hot encoding weighted by 10 to ensure selected experts have highest scores
+     one_hot = torch.nn.functional.one_hot(top_k_indices.long(),
+                                           num_classes=num_experts).float()
+     one_hot = one_hot.sum(dim=1) * 10
+     score = (score + one_hot).to(dtype)
+
+     engine_args = EngineArgs(
+         model="Qwen/Qwen2-1.5B-Instruct",
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.model_config.dtype = dtype
+     vllm_config.parallel_config = ParallelConfig(
+         tensor_parallel_size=mesh.devices.size, enable_expert_parallel=True)
+
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         vllm_fused_moe = FusedMoE(
+             num_experts=num_experts,
+             top_k=topk,
+             hidden_size=hidden_size,
+             intermediate_size=intermediate_size,
+             reduce_results=True,
+             renormalize=False,
+             tp_size=mesh.devices.size,
+             dp_size=1,
+             quant_config=quant_config,
+             has_bias=has_bias,
+         )
+     vllm_fused_moe.moe_parallel_config.use_ep = True
+
+     vllm_fused_moe.w13_weight.data = w1
+     vllm_fused_moe.w2_weight.data = w2
+     if has_bias:
+         vllm_fused_moe.w13_bias.data = w1_bias
+         vllm_fused_moe.w2_bias.data = w2_bias
+
+     expected = test_utils.ref_moe(a, score, w1, w2, w1_bias, w2_bias,
+                                   vllm_fused_moe.top_k,
+                                   vllm_fused_moe.renormalize,
+                                   vllm_fused_moe.activation)
+
+     with torchax.default_env(), set_forward_context(None, vllm_config):
+         assert isinstance(vllm_fused_moe.quant_method,
+                           VllmUnquantizedFusedMoEMethod)
+         assert vllm_fused_moe.quant_method.moe_backend == FusedMoEBackend.FUSED_MOE
+
+         jax_a = a.to('jax')
+         score = score.to('jax')
+
+         vllm_fused_moe.quant_method.process_weights_after_loading(
+             vllm_fused_moe)
+         vllm_fused_moe.quant_method.extra_backend_kwargs.update({
+             "bt": 32,
+             "bf": 512,
+             "bd1": 512,
+             "bd2": 512,
+             "btc": 32,
+             "bfc": 256,
+             "bd1c": 256,
+             "bd2c": 256,
+         })
+         actual = vllm_fused_moe(jax_a, score)
+
+     torch.testing.assert_close(
+         expected,
+         actual,
+         check_device=False,
+         atol=1e-2,
+         rtol=1e-2,
+     )