tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.2rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (250)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +89 -26
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +46 -17
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +44 -17
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/METADATA +7 -9
  240. tpu_inference-0.13.2rc3.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/top_level.txt +0 -0
tests/layers/vllm/test_awq.py
@@ -0,0 +1,405 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import tempfile
+ from typing import Optional
+
+ import jax
+ import pytest
+ import torch
+ import torchax
+ from jax.sharding import PartitionSpec
+ from torchax.interop import torch_view
+ from torchax.ops.mappings import j2t, t2j
+ from vllm.config import set_current_vllm_config
+ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                              init_distributed_environment)
+ from vllm.engine.arg_utils import EngineArgs
+ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                                LinearBase,
+                                                MergedColumnParallelLinear,
+                                                QKVParallelLinear,
+                                                RowParallelLinear)
+ from vllm.model_executor.layers.quantization.utils.quant_utils import \
+     pack_quantized_values_into_int32
+ from vllm.model_executor.model_loader import get_model as vllm_get_model
+ from vllm.scalar_type import scalar_types
+
+ from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
+ from tpu_inference.layers.vllm.quantization.awq import (VllmAWQConfig,
+                                                         VllmAWQLinearMethod)
+ from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig
+
+ from . import utils as test_utils
+
+ P = PartitionSpec
+ MODELS = ["Qwen/Qwen2.5-1.5B-Instruct-AWQ"]
+
+
+ def ref_quantize_uint4(x: torch.Tensor, group_size: int):
+     uint4_max = 15
+
+     # For group quantization, we reshape so that x[0], x[1], ... x[i] are
+     # quantized with different scale values.
+     x = torch.reshape(x, (-1, group_size) + (x.shape[1:]))
+
+     # Equation for asymmetric quantization is x_q = (x + x_z) / scale where
+     # x_z is calculated to ensure x + x_z does not contain any negative values.
+     offset = torch.clamp(-torch.amin(x, dim=1, keepdim=True), min=0)
+     x += offset
+     # After adding offset, x will not contain any negative values.
+     assert x.min() >= 0
+
+     x_abs_max = torch.amax(x, dim=1, keepdim=True)
+     x_s = x_abs_max / uint4_max
+     # torch does not support uint4, therefore, we cast to int32 instead.
+     x_q = torch.clip(x / x_s, 0, uint4_max).to(torch.int32)
+     x_z = torch.clip(offset / x_s, 0, uint4_max).to(torch.int32)
+     return x_q, x_z, x_s.to(torch.float32)
+
+
+ def ref_w4a16(x: torch.Tensor, w_q: torch.Tensor, w_z: torch.Tensor,
+               w_s: torch.Tensor, b: Optional[torch.Tensor]):
+     # Dequantize the asymmetrically quantized weight.
+     w = (w_q.to(torch.float32) - w_z.to(torch.float32)) * w_s
+     w = w.reshape((-1, w.shape[-1]))
+     out = torch.einsum('bd,df->bf', x.to(torch.float32), w)
+     if b is not None:
+         out += b
+     return out.to(x.dtype)
+
+
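The two helpers above define the asymmetric uint4 reference: quantize per group with an offset and a scale, then dequantize as (w_q - w_z) * w_s. A minimal round-trip sketch, assuming the helpers above are in scope (the shapes, the group size of 4, and the 0.2 shift are illustrative only, not taken from the tests):

w = torch.rand((8, 16)) - 0.2                          # asymmetric values, as in the tests
x = torch.rand((2, 8), dtype=torch.bfloat16) / 10
ref = (x.to(torch.float32) @ w).to(torch.bfloat16)     # unquantized baseline, taken first
# ref_quantize_uint4 reshapes via a view and adds the offset in place, so the
# baseline is computed before quantizing.
w_q, w_z, w_s = ref_quantize_uint4(w, group_size=4)
out = ref_w4a16(x, w_q, w_z, w_s, b=None)
print((out.float() - ref.float()).abs().max())         # small uint4 quantization error
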
+ def pack_awq_weight_into_int32(weight: torch.Tensor):
+     # AWQ packs 8 uint4 into 32-bits in this order.
+     awq_order = (0, 2, 4, 6, 1, 3, 5, 7)
+
+     orig_shape = weight.shape
+     weight = weight.reshape(orig_shape[:-1] + (-1, 8))
+     weight = weight[..., awq_order].reshape(orig_shape)
+
+     return pack_quantized_values_into_int32(weight, scalar_types.uint4, 1)
+
+
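For readers unfamiliar with the AWQ layout: the eight uint4 values that share one int32 are interleaved, not stored sequentially, before vLLM's pack_quantized_values_into_int32 does the bit packing. A tiny illustration of just the reordering step, with made-up values:

nibbles = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=torch.int32)
# awq_order = (0, 2, 4, 6, 1, 3, 5, 7): even positions first, then odd ones.
reordered = nibbles[..., [0, 2, 4, 6, 1, 3, 5, 7]]
print(reordered)  # tensor([[0, 2, 4, 6, 1, 3, 5, 7]])
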
+ def return_ref_and_layer_output(
+     layer: torch.nn.Module,
+     qweight: torch.Tensor,
+     qzeros: torch.Tensor,
+     scales: torch.Tensor,
+     batch_size: int = 16,
+ ):
+     assert isinstance(layer, LinearBase)
+     quant_method = layer.quant_method
+     assert isinstance(quant_method, VllmAWQLinearMethod)
+     quant_config = quant_method.quant_config
+     assert isinstance(quant_config, VllmAWQConfig)
+     jax_config = quant_method.jax_config
+     assert isinstance(jax_config, JaxCommonLinearConfig)
+
+     input_tensor = torch.rand(
+         batch_size, layer.input_size, dtype=torch.bfloat16) / 10
+     input_tensor = input_tensor.to('cpu')
+
+     ref_output = ref_w4a16(
+         input_tensor,
+         qweight,
+         qzeros,
+         scales,
+         layer.bias,
+     )
+
+     # Run torchax/jax function
+     quant_method.process_weights_after_loading(layer)
+     with torchax.default_env():
+         jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
+         layer_output = layer(jax_input_tensor)
+         layer_output = j2t(layer_output.to(torch.float32)).to(torch.bfloat16)
+
+     return ref_output, layer_output
+
+
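The torch-to-JAX bridge in the helper above is the torchax interop pattern: t2j converts a torch tensor to a JAX array, torch_view wraps it so torch ops dispatch to JAX, and j2t converts the result back. A minimal standalone sketch of that pattern (shapes arbitrary, unrelated to the AWQ layers):

import torch
import torchax
from torchax.interop import torch_view
from torchax.ops.mappings import j2t, t2j

x = torch.rand(4, 8, dtype=torch.bfloat16)
with torchax.default_env():
    x_jax = torch_view(t2j(x, use_dlpack=False))  # torch -> JAX-backed tensor view
    y_jax = x_jax * 2                             # dispatched through torchax
    y = j2t(y_jax.to(torch.float32))              # back to a plain torch tensor
print(y.shape)
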
+ def initialize_and_return_layer_weights(layer: torch.nn.Module):
+     assert isinstance(layer, LinearBase)
+     quant_method = layer.quant_method
+     assert isinstance(quant_method, VllmAWQLinearMethod)
+     quant_config = quant_method.quant_config
+     assert isinstance(quant_config, VllmAWQConfig)
+     jax_config = quant_method.jax_config
+     assert isinstance(jax_config, JaxCommonLinearConfig)
+
+     # torch.rand returns values in the range [0, 1). We subtract 0.2 to
+     # simulate asymmetry.
+     weight = torch.rand((layer.input_size, layer.output_size)) - 0.2
+     qweight, qzeros, scales = ref_quantize_uint4(weight,
+                                                  quant_config.group_size)
+
+     # We modify uint4 quantized weights into AWQ format.
+     layer_qweight = qweight.reshape((-1, layer.output_size))
+     layer_qzeros = qzeros.reshape((-1, layer.output_size))
+     layer_scales = scales.reshape((-1, layer.output_size))
+
+     layer_qweight = pack_awq_weight_into_int32(layer_qweight)
+     layer_qzeros = pack_awq_weight_into_int32(layer_qzeros)
+
+     assert layer.qweight.data.shape == layer_qweight.shape
+     assert layer.qzeros.data.shape == layer_qzeros.shape
+     assert layer.scales.data.shape == layer_scales.shape
+
+     layer.qweight.data = layer_qweight
+     layer.qzeros.data = layer_qzeros
+     layer.scales.data = layer_scales
+
+     bias = None
+     if layer.bias is not None:
+         bias = torch.rand_like(layer.bias.data)
+         layer.bias.data = bias
+
+     return qweight, qzeros, scales, bias
+
+
+ @pytest.fixture(autouse=True)
+ def setup_environment():
+     # This is a fake config used for init dist env.
+     # RowParallelLinear needs dist env to be initialized.
+     engine_args = EngineArgs(
+         model=MODELS[0],
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+
+     vllm_config = engine_args.create_engine_config()
+
+     with set_current_vllm_config(vllm_config):
+         temp_file = tempfile.mkstemp()[1]
+         init_distributed_environment(
+             1,
+             0,
+             local_rank=0,
+             distributed_init_method=f"file://{temp_file}",
+             backend="gloo")
+         ensure_model_parallel_initialized(1, 1)
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("mesh", [
+     test_utils.get_spmd_mesh(1),
+     test_utils.get_spmd_mesh(jax.local_device_count())
+ ])
+ def test_quant_override(model, mesh):
+
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.model_config.dtype = torch.bfloat16
+
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     assert isinstance(quant_config, VllmAWQConfig)
+     assert quant_config.vllm_config == vllm_config
+     assert quant_config.mesh == mesh
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize(
+     "mesh",
+     [
+         test_utils.get_spmd_mesh(1),
+         # We limit the device count to 2 instead of using all devices (e.g. 8)
+         # since AWQ requires n_groups to be divisible by the number of shards.
+         # Qwen uses a group size of 128 and one of the layers has an input
+         # size of 1536, meaning n_groups = 1536 // 128 = 12, which is not
+         # divisible by 8.
+         test_utils.get_spmd_mesh(min(jax.local_device_count(), 2))
+     ])
+ def test_loading_model(model, mesh):
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.model_config.dtype = torch.bfloat16
+     vllm_config.quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     vllm_config.device_config.device = "cpu"
+
+     vllm_model = vllm_get_model(vllm_config=vllm_config)
+     layers = test_utils.find_all_layer_type(vllm_model, LinearBase)
+     for layer in layers:
+         assert isinstance(layer.quant_config, VllmAWQConfig)
+         assert isinstance(layer.quant_method, VllmAWQLinearMethod)
+
+
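The comment in test_loading_model above encodes a sharding constraint: with AWQ group size 128, a layer whose input size is 1536 has n_groups = 12, so the shard count must divide 12. A one-line check of that arithmetic (values taken from the comment; nothing here is part of the test suite):

input_size, group_size = 1536, 128
n_groups = input_size // group_size        # 12
for shards in (1, 2, 4, 8):
    print(shards, n_groups % shards == 0)  # only 8 fails: 12 % 8 != 0
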
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("bias", [False, True])
+ @pytest.mark.parametrize("mesh", [
+     test_utils.get_spmd_mesh(1),
+     test_utils.get_spmd_mesh(jax.local_device_count())
+ ])
+ @pytest.mark.parametrize("enable_sp", [False, True])
+ def test_row_parallel_linear(model, bias, mesh, enable_sp):
+     dtype = torch.bfloat16
+
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+     vllm_config.model_config.dtype = dtype
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         linear_layer = RowParallelLinear(
+             input_size=4096,
+             output_size=8192,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+             quant_config=quant_config,
+         )
+
+     qweight, qzeros, scales, _ = initialize_and_return_layer_weights(
+         linear_layer)
+     ref_output, layer_output = return_ref_and_layer_output(
+         linear_layer, qweight, qzeros, scales)
+     torch.testing.assert_close(ref_output, layer_output)
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("bias", [False, True])
+ @pytest.mark.parametrize("mesh", [
+     test_utils.get_spmd_mesh(1),
+     test_utils.get_spmd_mesh(jax.local_device_count())
+ ])
+ @pytest.mark.parametrize("enable_sp", [False, True])
+ def test_column_parallel_linear(model, bias, mesh, enable_sp):
+     dtype = torch.bfloat16
+
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+     # Call tpu_inference code
+     vllm_config.model_config.dtype = torch.bfloat16
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         linear_layer = ColumnParallelLinear(
+             input_size=4096,
+             output_size=8192,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+             quant_config=quant_config,
+         )
+
+     qweight, qzeros, scales, _ = initialize_and_return_layer_weights(
+         linear_layer)
+     ref_output, layer_output = return_ref_and_layer_output(
+         linear_layer, qweight, qzeros, scales)
+     torch.testing.assert_close(ref_output, layer_output)
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("bias", [False, True])
+ @pytest.mark.parametrize("mesh", [
+     test_utils.get_spmd_mesh(1),
+     test_utils.get_spmd_mesh(jax.local_device_count())
+ ])
+ @pytest.mark.parametrize("enable_sp", [False, True])
+ @pytest.mark.parametrize("fuse_matmuls", [False, True])
+ def test_qkv_parallel_linear(model, bias, mesh, enable_sp, fuse_matmuls):
+     dtype = torch.bfloat16
+
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+     # Call tpu_inference code
+     vllm_config.model_config.dtype = torch.bfloat16
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         linear_layer = QKVParallelLinear(
+             hidden_size=4096,
+             head_size=128,
+             total_num_heads=32,
+             total_num_kv_heads=8,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+             quant_config=quant_config,
+         )
+         linear_layer.quant_method.fuse_matmuls = fuse_matmuls
+
+     qweight, qzeros, scales, _ = initialize_and_return_layer_weights(
+         linear_layer)
+     ref_output, layer_output = return_ref_and_layer_output(
+         linear_layer, qweight, qzeros, scales)
+     torch.testing.assert_close(ref_output, layer_output)
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("bias", [False, True])
+ @pytest.mark.parametrize("mesh", [
+     test_utils.get_spmd_mesh(1),
+     test_utils.get_spmd_mesh(jax.local_device_count())
+ ])
+ @pytest.mark.parametrize("fuse_matmuls", [False, True])
+ @pytest.mark.parametrize("enable_sp", [False, True])
+ def test_merged_column_parallel_linear(model, bias, mesh, fuse_matmuls,
+                                        enable_sp):
+     dtype = torch.bfloat16
+
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+     # Call tpu_inference code
+     vllm_config.model_config.dtype = torch.bfloat16
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         linear_layer = MergedColumnParallelLinear(
+             input_size=4096,
+             output_sizes=[14336] * 2,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+             quant_config=quant_config,
+         )
+         linear_layer.quant_method.fuse_matmuls = fuse_matmuls
+
+     qweight, qzeros, scales, _ = initialize_and_return_layer_weights(
+         linear_layer)
+     ref_output, layer_output = return_ref_and_layer_output(
+         linear_layer, qweight, qzeros, scales)
+     torch.testing.assert_close(ref_output, layer_output)
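
All four linear-layer tests above follow the same pattern: build the vLLM layer under set_current_vllm_config with the TPU AWQ quant config, install reference uint4 weights, and compare the torchax output against ref_w4a16. If you only want to exercise this file, a pytest filter along these lines should work (path taken from the file list above; the -k expression is just an example):

import pytest
pytest.main(["tests/layers/vllm/test_awq.py", "-k", "row_parallel or column_parallel"])
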
tests/layers/vllm/test_compressed_tensors_moe.py
@@ -0,0 +1,202 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import tempfile
+
+ import jax.numpy as jnp
+ import pytest
+ import torch
+ import torch.nn.functional as F
+ import torchax
+ from compressed_tensors.quantization import QuantizationArgs
+ from jax.sharding import PartitionSpec
+ from vllm.config import set_current_vllm_config
+ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                              init_distributed_environment)
+ from vllm.engine.arg_utils import EngineArgs
+ from vllm.model_executor.layers.fused_moe import FusedMoE
+ # yapf: disable
+ from vllm.model_executor.layers.fused_moe.config import (
+     FusedMoEConfig, FusedMoEParallelConfig)
+
+ from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
+ from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
+     VllmCompressedTensorsConfig
+ from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
+     VllmCompressedTensorsW8A8Fp8MoEMethod
+
+ from . import utils as test_utils
+
+ # yapf: enable
+
+ P = PartitionSpec
+
+ os.environ['VLLM_DISABLE_SHARED_EXPERTS_STREAM'] = '1'
+
+ MODEL = 'BCCard/Qwen3-30B-A3B-FP8-Dynamic'
+
+
+ @pytest.fixture(autouse=True)
+ def setup_environment():
+     # This is a fake config used for init dist env.
+     # RowParallelLinear needs dist env to be initialized.
+     engine_args = EngineArgs(
+         model=MODEL,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+
+     vllm_config = engine_args.create_engine_config()
+
+     with set_current_vllm_config(vllm_config):
+         temp_file = tempfile.mkstemp()[1]
+         init_distributed_environment(
+             1,
+             0,
+             local_rank=0,
+             distributed_init_method=f"file://{temp_file}",
+             backend="gloo")
+         ensure_model_parallel_initialized(1, 1)
+
+
+ def _ref_math_in_bf16(w1, w2, w3, x, router_logits, top_k):
+     seqlen = x.shape[0]
+     expert_weights = F.softmax(router_logits, dim=-1)
+     expert_weights, expert_indices = torch.topk(expert_weights, top_k, dim=-1)
+     expert_weights /= expert_weights.sum(dim=-1, keepdim=True)
+
+     # cond ffn
+     # e = total num of experts = 160
+     # t = seqlen
+     # o = config.intermediate_size
+     # i = config.dim
+     x1 = torch.einsum("ti, eoi -> teo", x, w1)
+     x1 = F.silu(x1)
+     x3 = torch.einsum("ti, eoi -> teo", x, w3)
+     expert_outs = torch.einsum("teo, eio -> tei", (x1 * x3), w2)
+
+     seq_indexes = torch.arange(seqlen, device='jax').unsqueeze(1)
+     expert_outs = expert_outs[seq_indexes, expert_indices]
+     out = torch.einsum("tai,ta -> ti", expert_outs, expert_weights)
+     return out
+
+
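To make the einsum dimension comments above concrete: x is (t, i), each expert's gate/up weights are (e, o, i), the down weights are (e, i, o), and the dense per-expert outputs are (t, e, i) before the top-k gather. A tiny CPU shape walk-through with illustrative sizes (not the test's sizes, and no routing step):

import torch
import torch.nn.functional as F

t, e, o, i = 4, 8, 16, 32                              # tokens, experts, intermediate, hidden
x = torch.randn(t, i)
w1, w3 = torch.randn(e, o, i), torch.randn(e, o, i)    # gate / up projections per expert
w2 = torch.randn(e, i, o)                              # down projection per expert
h = F.silu(torch.einsum("ti, eoi -> teo", x, w1)) * torch.einsum("ti, eoi -> teo", x, w3)
expert_outs = torch.einsum("teo, eio -> tei", h, w2)
print(expert_outs.shape)                               # torch.Size([4, 8, 32])
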
+ @pytest.mark.parametrize(
+     "mesh", [test_utils.get_spmd_mesh(1),
+              test_utils.get_spmd_mesh(2)])
+ @pytest.mark.parametrize("num_tokens", [8])
+ @pytest.mark.parametrize("intermediate_size", [1024])
+ @pytest.mark.parametrize("hidden_size", [128])
+ @pytest.mark.parametrize("num_experts", [8])
+ @pytest.mark.parametrize("topk", [2])
+ @pytest.mark.parametrize("use_ep", [True, False])
+ def test_fused_moe_method(mesh, num_tokens, intermediate_size, hidden_size,
+                           num_experts, topk, use_ep):
+     engine_args = EngineArgs(
+         model=MODEL,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.compilation_config.pass_config.enable_sp = False
+
+     # Call tpu_inference code
+     vllm_config.model_config.dtype = torch.bfloat16
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+
+     with set_current_vllm_config(vllm_config):
+         layer = FusedMoE(num_experts=num_experts,
+                          top_k=topk,
+                          hidden_size=hidden_size,
+                          intermediate_size=intermediate_size)
+         quant_config = VllmCompressedTensorsConfig(
+             target_scheme_map={
+                 'Linear': {
+                     'weights':
+                     QuantizationArgs(num_bits=8,
+                                      type='float',
+                                      symmetric=True,
+                                      group_size=None,
+                                      strategy='channel',
+                                      block_structure=None,
+                                      dynamic=False,
+                                      actorder=None,
+                                      observer='minmax',
+                                      observer_kwargs={}),
+                     'input_activations':
+                     QuantizationArgs(num_bits=8,
+                                      type='float',
+                                      symmetric=True,
+                                      group_size=None,
+                                      strategy='token',
+                                      block_structure=None,
+                                      dynamic=True,
+                                      actorder=None,
+                                      observer=None,
+                                      observer_kwargs={}),
+                     'format':
+                     None
+                 }
+             },
+             ignore=[],
+             quant_format='compressed-tensors',
+             sparsity_scheme_map={},
+             sparsity_ignore_list=[],
+         )
+         moe = FusedMoEConfig(
+             num_experts=num_experts,
+             experts_per_token=topk,
+             hidden_dim=hidden_size,
+             num_local_experts=num_experts,
+             moe_parallel_config=FusedMoEParallelConfig(
+                 tp_size=1,
+                 dp_size=1,
+                 ep_size=1,
+                 tp_rank=0,
+                 dp_rank=0,
+                 ep_rank=0,
+                 use_ep=use_ep,
+                 all2all_backend='',
+             ),
+             in_dtype=torch.bfloat16,
+         )
+         method = VllmCompressedTensorsW8A8Fp8MoEMethod(quant_config, moe, mesh)
+         method.create_weights(layer,
+                               num_experts,
+                               hidden_size,
+                               intermediate_size,
+                               params_dtype=torch.float8_e4m3fn)
+         method.process_weights_after_loading(layer)
+
+         seqlen = num_tokens
+         with torchax.default_env():
+             x = torch.ones((seqlen, hidden_size), dtype=torch.bfloat16).to('jax')
+             router_logits = torch.randn((seqlen, num_experts),
+                                         dtype=torch.bfloat16).to('jax')
+             result = method.apply(layer,
+                                   x,
+                                   router_logits,
+                                   top_k=topk,
+                                   renormalize=True)
+
+             result_reference = _ref_math_in_bf16(
+                 layer.w13_weight.to(torch.bfloat16) * layer.w13_weight_scale,
+                 layer.w2_weight.to(torch.bfloat16) * layer.w2_weight_scale,
+                 layer.w3_weight.to(torch.bfloat16) * layer.w3_weight_scale, x,
+                 router_logits, topk)
+
+             assert jnp.allclose(result.jax(), result_reference.jax())
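
The reference path at the end dequantizes the fp8 expert weights with a plain per-channel scale multiply, which is what the 'channel' weight strategy in the QuantizationArgs above implies (activations use dynamic per-token scales instead). A minimal sketch of that weight dequantization step, with made-up shapes and standalone tensors rather than the layer attributes:

import torch

e, o, i = 8, 16, 32
w_fp8 = torch.randn(e, o, i).to(torch.float8_e4m3fn)    # quantized expert weights
w_scale = torch.rand(e, o, 1, dtype=torch.bfloat16)     # one scale per output channel
w_bf16 = w_fp8.to(torch.bfloat16) * w_scale             # mirrors w13_weight * w13_weight_scale above
print(w_bf16.dtype, w_bf16.shape)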