tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tests/layers/vllm/test_awq.py
@@ -0,0 +1,406 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
from typing import Optional

import jax
import pytest
import torch
import torchax
from jax.sharding import PartitionSpec
from torchax.interop import torch_view
from torchax.ops.mappings import j2t, t2j
from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                             init_distributed_environment)
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization.utils.quant_utils import \
    pack_quantized_values_into_int32
from vllm.model_executor.model_loader import get_model as vllm_get_model
from vllm.scalar_type import scalar_types

from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
from tpu_inference.layers.vllm.quantization.awq import (VllmAWQConfig,
                                                        VllmAWQLinearMethod)
from tpu_inference.layers.vllm.quantization.configs import \
    VllmQuantLinearConfig

from . import utils as test_utils

P = PartitionSpec
MODELS = ["Qwen/Qwen2.5-1.5B-Instruct-AWQ"]


def ref_quantize_uint4(x: torch.Tensor, group_size: int):
    uint4_max = 15

    # For group quantization, we reshape so that x[0], x[1], ... x[i] are
    # quantized with different scale values.
    x = torch.reshape(x, (-1, group_size) + (x.shape[1:]))

    # Equation for asymmetric quantization is x_q = (x + x_z) / scale where
    # x_z is calculated to ensure x + x_z does not contain any negative values.
    offset = torch.clamp(-torch.amin(x, dim=1, keepdim=True), min=0)
    x += offset
    # After adding offset, x will not contain any negative values.
    assert x.min() >= 0

    x_abs_max = torch.amax(x, dim=1, keepdim=True)
    x_s = x_abs_max / uint4_max
    # torch does not support uint4, therefore, we cast to int32 instead.
    x_q = torch.clip(x / x_s, 0, uint4_max).to(torch.int32)
    x_z = torch.clip(offset / x_s, 0, uint4_max).to(torch.int32)
    return x_q, x_z, x_s.to(torch.float32)


def ref_w4a16(x: torch.Tensor, w_q: torch.Tensor, w_z: torch.Tensor,
              w_s: torch.Tensor, b: Optional[torch.Tensor]):
    # Dequantize asymmetrically quantized weight.
    w = (w_q.to(torch.float32) - w_z.to(torch.float32)) * w_s
    w = w.reshape((-1, w.shape[-1]))
    out = torch.einsum('bd,df->bf', x.to(torch.float32), w)
    if b is not None:
        out += b
    return out.to(x.dtype)


def pack_awq_weight_into_int32(weight: torch.Tensor):
    # AWQ packs 8 uint4 into 32-bits in this order.
    awq_order = (0, 2, 4, 6, 1, 3, 5, 7)

    orig_shape = weight.shape
    weight = weight.reshape(orig_shape[:-1] + (-1, 8))
    weight = weight[..., awq_order].reshape(orig_shape)

    return pack_quantized_values_into_int32(weight, scalar_types.uint4, 1)


def return_ref_and_layer_output(
    layer: torch.nn.Module,
    qweight: torch.Tensor,
    qzeros: torch.Tensor,
    scales: torch.Tensor,
    batch_size: int = 16,
):
    assert isinstance(layer, LinearBase)
    quant_method = layer.quant_method
    assert isinstance(quant_method, VllmAWQLinearMethod)
    quant_config = quant_method.quant_config
    assert isinstance(quant_config, VllmAWQConfig)
    jax_config = quant_method.linear_config
    assert isinstance(jax_config, VllmQuantLinearConfig)

    input_tensor = torch.rand(
        batch_size, layer.input_size, dtype=torch.bfloat16) / 10
    input_tensor = input_tensor.to('cpu')

    ref_output = ref_w4a16(
        input_tensor,
        qweight,
        qzeros,
        scales,
        layer.bias,
    )

    # Run torchax/jax function
    quant_method.process_weights_after_loading(layer)
    with torchax.default_env():
        jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
        layer_output = layer(jax_input_tensor)
        layer_output = j2t(layer_output.to(torch.float32)).to(torch.bfloat16)

    return ref_output, layer_output


def initialize_and_return_layer_weights(layer: torch.nn.Module):
    assert isinstance(layer, LinearBase)
    quant_method = layer.quant_method
    assert isinstance(quant_method, VllmAWQLinearMethod)
    quant_config = quant_method.quant_config
    assert isinstance(quant_config, VllmAWQConfig)
    jax_config = quant_method.linear_config
    assert isinstance(jax_config, VllmQuantLinearConfig)

    # torch.rand returns values in the range of [0, 1). We subtract 0.2 to
    # simulate asymmetry.
    weight = torch.rand((layer.input_size, layer.output_size)) - 0.2
    qweight, qzeros, scales = ref_quantize_uint4(weight,
                                                 quant_config.group_size)

    # We modify uint4 quantized weights into AWQ format.
    layer_qweight = qweight.reshape((-1, layer.output_size))
    layer_qzeros = qzeros.reshape((-1, layer.output_size))
    layer_scales = scales.reshape((-1, layer.output_size))

    layer_qweight = pack_awq_weight_into_int32(layer_qweight)
    layer_qzeros = pack_awq_weight_into_int32(layer_qzeros)

    assert layer.qweight.data.shape == layer_qweight.shape
    assert layer.qzeros.data.shape == layer_qzeros.shape
    assert layer.scales.data.shape == layer_scales.shape

    layer.qweight.data = layer_qweight
    layer.qzeros.data = layer_qzeros
    layer.scales.data = layer_scales

    bias = None
    if layer.bias is not None:
        bias = torch.rand_like(layer.bias.data)
        layer.bias.data = bias

    return qweight, qzeros, scales, bias


@pytest.fixture(autouse=True)
def setup_environment():
    # This is a fake config used for init dist env.
    # RowParallelLinear needs dist env to be initialized.
    engine_args = EngineArgs(
        model=MODELS[0],
        max_model_len=64,
        max_num_batched_tokens=64,
        max_num_seqs=4,
    )

    vllm_config = engine_args.create_engine_config()

    with set_current_vllm_config(vllm_config):
        temp_file = tempfile.mkstemp()[1]
        init_distributed_environment(
            1,
            0,
            local_rank=0,
            distributed_init_method=f"file://{temp_file}",
            backend="gloo")
        ensure_model_parallel_initialized(1, 1)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("mesh", [
    test_utils.get_spmd_mesh(1),
    test_utils.get_spmd_mesh(jax.local_device_count())
])
def test_quant_override(model, mesh):

    engine_args = EngineArgs(
        model=model,
        max_model_len=64,
        max_num_batched_tokens=64,
        max_num_seqs=4,
    )
    vllm_config = engine_args.create_engine_config()
    vllm_config.model_config.dtype = torch.bfloat16

    quant_config = get_tpu_quantization_config(vllm_config, mesh)
    assert isinstance(quant_config, VllmAWQConfig)
    assert quant_config.vllm_config == vllm_config
    assert quant_config.mesh == mesh


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize(
    "mesh",
    [
        test_utils.get_spmd_mesh(1),
        # We limit device count to 2 instead of using all devices (like 8) since
        # AWQ requires n_groups to be divisible by the number of shards. Qwen
        # uses a group size of 128 and one of the layers has an input size of
        # 1536, meaning n_groups = 1536//128 = 12 - which is not divisible by 8.
        test_utils.get_spmd_mesh(min(jax.local_device_count(), 2))
    ])
def test_loading_model(model, mesh):
    engine_args = EngineArgs(
        model=model,
        max_model_len=64,
        max_num_batched_tokens=64,
        max_num_seqs=4,
    )
    vllm_config = engine_args.create_engine_config()
    vllm_config.model_config.dtype = torch.bfloat16
    vllm_config.quant_config = get_tpu_quantization_config(vllm_config, mesh)
    vllm_config.device_config.device = "cpu"

    vllm_model = vllm_get_model(vllm_config=vllm_config)
    layers = test_utils.find_all_layer_type(vllm_model, LinearBase)
    for layer in layers:
        assert isinstance(layer.quant_config, VllmAWQConfig)
        assert isinstance(layer.quant_method, VllmAWQLinearMethod)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("bias", [False, True])
@pytest.mark.parametrize("mesh", [
    test_utils.get_spmd_mesh(1),
    test_utils.get_spmd_mesh(jax.local_device_count())
])
@pytest.mark.parametrize("enable_sp", [False, True])
def test_row_parallel_linear(model, bias, mesh, enable_sp):
    dtype = torch.bfloat16

    engine_args = EngineArgs(
        model=model,
        max_model_len=64,
        max_num_batched_tokens=64,
        max_num_seqs=4,
    )
    vllm_config = engine_args.create_engine_config()
    vllm_config.compilation_config.pass_config.enable_sp = enable_sp

    vllm_config.model_config.dtype = dtype
    quant_config = get_tpu_quantization_config(vllm_config, mesh)
    with set_current_vllm_config(vllm_config):
        linear_layer = RowParallelLinear(
            input_size=4096,
            output_size=8192,
            bias=bias,
            params_dtype=dtype,
            return_bias=False,
            quant_config=quant_config,
        )

    qweight, qzeros, scales, _ = initialize_and_return_layer_weights(
        linear_layer)
    ref_output, layer_output = return_ref_and_layer_output(
        linear_layer, qweight, qzeros, scales)
    torch.testing.assert_close(ref_output, layer_output)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("bias", [False, True])
@pytest.mark.parametrize("mesh", [
    test_utils.get_spmd_mesh(1),
    test_utils.get_spmd_mesh(jax.local_device_count())
])
@pytest.mark.parametrize("enable_sp", [False, True])
def test_column_parallel_linear(model, bias, mesh, enable_sp):
    dtype = torch.bfloat16

    engine_args = EngineArgs(
        model=model,
        max_model_len=64,
        max_num_batched_tokens=64,
        max_num_seqs=4,
    )
    vllm_config = engine_args.create_engine_config()
    vllm_config.compilation_config.pass_config.enable_sp = enable_sp

    # Call tpu_inference code
    vllm_config.model_config.dtype = torch.bfloat16
    quant_config = get_tpu_quantization_config(vllm_config, mesh)
    with set_current_vllm_config(vllm_config):
        linear_layer = ColumnParallelLinear(
            input_size=4096,
            output_size=8192,
            bias=bias,
            params_dtype=dtype,
            return_bias=False,
            quant_config=quant_config,
        )

    qweight, qzeros, scales, _ = initialize_and_return_layer_weights(
        linear_layer)
    ref_output, layer_output = return_ref_and_layer_output(
        linear_layer, qweight, qzeros, scales)
    torch.testing.assert_close(ref_output, layer_output)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("bias", [False, True])
@pytest.mark.parametrize("mesh", [
    test_utils.get_spmd_mesh(1),
    test_utils.get_spmd_mesh(jax.local_device_count())
])
@pytest.mark.parametrize("enable_sp", [False, True])
@pytest.mark.parametrize("fuse_matmuls", [False, True])
def test_qkv_parallel_linear(model, bias, mesh, enable_sp, fuse_matmuls):
    dtype = torch.bfloat16

    engine_args = EngineArgs(
        model=model,
        max_model_len=64,
        max_num_batched_tokens=64,
        max_num_seqs=4,
    )
    vllm_config = engine_args.create_engine_config()
    vllm_config.compilation_config.pass_config.enable_sp = enable_sp

    # Call tpu_inference code
    vllm_config.model_config.dtype = torch.bfloat16
    quant_config = get_tpu_quantization_config(vllm_config, mesh)
    with set_current_vllm_config(vllm_config):
        linear_layer = QKVParallelLinear(
            hidden_size=4096,
            head_size=128,
            total_num_heads=32,
            total_num_kv_heads=8,
            bias=bias,
            params_dtype=dtype,
            return_bias=False,
            quant_config=quant_config,
        )
        linear_layer.quant_method.fuse_matmuls = fuse_matmuls

    qweight, qzeros, scales, _ = initialize_and_return_layer_weights(
        linear_layer)
    ref_output, layer_output = return_ref_and_layer_output(
        linear_layer, qweight, qzeros, scales)
    torch.testing.assert_close(ref_output, layer_output)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("bias", [False, True])
@pytest.mark.parametrize("mesh", [
    test_utils.get_spmd_mesh(1),
    test_utils.get_spmd_mesh(jax.local_device_count())
])
@pytest.mark.parametrize("fuse_matmuls", [False, True])
@pytest.mark.parametrize("enable_sp", [False, True])
def test_merged_column_parallel_linear(model, bias, mesh, fuse_matmuls,
                                       enable_sp):
    dtype = torch.bfloat16

    engine_args = EngineArgs(
        model=model,
        max_model_len=64,
        max_num_batched_tokens=64,
        max_num_seqs=4,
    )
    vllm_config = engine_args.create_engine_config()
    vllm_config.compilation_config.pass_config.enable_sp = enable_sp

    # Call tpu_inference code
    vllm_config.model_config.dtype = torch.bfloat16
    quant_config = get_tpu_quantization_config(vllm_config, mesh)
    with set_current_vllm_config(vllm_config):
        linear_layer = MergedColumnParallelLinear(
            input_size=4096,
            output_sizes=[14336] * 2,
            bias=bias,
            params_dtype=dtype,
            return_bias=False,
            quant_config=quant_config,
        )
        linear_layer.quant_method.fuse_matmuls = fuse_matmuls

    qweight, qzeros, scales, _ = initialize_and_return_layer_weights(
        linear_layer)
    ref_output, layer_output = return_ref_and_layer_output(
        linear_layer, qweight, qzeros, scales)
    torch.testing.assert_close(ref_output, layer_output)
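
The AWQ nibble ordering handled by pack_awq_weight_into_int32 above can be seen on a toy example. The sketch below is a standalone illustration, not the vLLM packer: the helper name pack8_uint4_awq is made up here, and the assumption that the first reordered value lands in the lowest nibble of the int32 mirrors what pack_quantized_values_into_int32 is expected to do but is not taken from the package.

import torch

# AWQ interleaves each group of 8 uint4 values as (0, 2, 4, 6, 1, 3, 5, 7)
# before packing them into one int32 (assumed lowest nibble first).
AWQ_ORDER = (0, 2, 4, 6, 1, 3, 5, 7)


def pack8_uint4_awq(vals: torch.Tensor) -> int:
    assert vals.shape == (8,)
    reordered = vals[list(AWQ_ORDER)].tolist()
    packed = 0
    for i, v in enumerate(reordered):
        packed |= (int(v) & 0xF) << (4 * i)
    return packed


vals = torch.arange(8, dtype=torch.int32)  # [0, 1, 2, 3, 4, 5, 6, 7]
print(hex(pack8_uint4_awq(vals)))          # 0x75316420

Reading the printed value from the most significant nibble down gives 7-5-3-1-6-4-2-0, i.e. the original even-indexed values sit in the low half of each pair and the odd-indexed values in the high half.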
tests/layers/vllm/test_compressed_tensors_moe.py
@@ -0,0 +1,199 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile

import jax.numpy as jnp
import pytest
import torch
import torch.nn.functional as F
import torchax
from compressed_tensors.quantization import QuantizationArgs
from jax.sharding import PartitionSpec
from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                             init_distributed_environment)
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.fused_moe import FusedMoE
# yapf: disable
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig, FusedMoEParallelConfig)

from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
    VllmCompressedTensorsConfig
from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
    VllmCompressedTensorsW8A8Fp8MoEMethod

from . import utils as test_utils

# yapf: enable

P = PartitionSpec

MODEL = 'BCCard/Qwen3-30B-A3B-FP8-Dynamic'


@pytest.fixture(autouse=True)
def setup_environment():
    # This is a fake config used for init dist env.
    # RowParallelLinear needs dist env to be initialized.
    engine_args = EngineArgs(
        model=MODEL,
        max_model_len=64,
        max_num_batched_tokens=64,
        max_num_seqs=4,
    )

    vllm_config = engine_args.create_engine_config()

    with set_current_vllm_config(vllm_config):
        temp_file = tempfile.mkstemp()[1]
        init_distributed_environment(
            1,
            0,
            local_rank=0,
            distributed_init_method=f"file://{temp_file}",
            backend="gloo")
        ensure_model_parallel_initialized(1, 1)


def _ref_math_in_bf16(w1, w2, w3, x, router_logits, top_k):
    seqlen = x.shape[0]
    expert_weights = F.softmax(router_logits, dim=-1)
    expert_weights, expert_indices = torch.topk(expert_weights, top_k, dim=-1)
    expert_weights /= expert_weights.sum(dim=-1, keepdim=True)

    # cond ffn
    # e = total num of experts = 160
    # t = seqlen
    # o = config.intermediate_size
    # i = config.dim
    x1 = torch.einsum("ti, eoi -> teo", x, w1)
    x1 = F.silu(x1)
    x3 = torch.einsum("ti, eoi -> teo", x, w3)
    expert_outs = torch.einsum("teo, eio -> tei", (x1 * x3), w2)

    seq_indexes = torch.arange(seqlen, device='jax').unsqueeze(1)
    expert_outs = expert_outs[seq_indexes, expert_indices]
    out = torch.einsum("tai,ta -> ti", expert_outs, expert_weights)
    return out


@pytest.mark.parametrize(
    "mesh", [test_utils.get_spmd_mesh(1),
             test_utils.get_spmd_mesh(2)])
@pytest.mark.parametrize("num_tokens", [8])
@pytest.mark.parametrize("intermediate_size", [1024])
@pytest.mark.parametrize("hidden_size", [128])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("topk", [2])
@pytest.mark.parametrize("use_ep", [True, False])
def test_fused_moe_method(mesh, num_tokens, intermediate_size, hidden_size,
                          num_experts, topk, use_ep):
    engine_args = EngineArgs(
        model=MODEL,
        max_model_len=64,
        max_num_batched_tokens=64,
        max_num_seqs=4,
    )
    vllm_config = engine_args.create_engine_config()
    vllm_config.compilation_config.pass_config.enable_sp = False

    # Call tpu_inference code
    vllm_config.model_config.dtype = torch.bfloat16
    quant_config = get_tpu_quantization_config(vllm_config, mesh)

    with set_current_vllm_config(vllm_config):
        layer = FusedMoE(num_experts=num_experts,
                         top_k=topk,
                         hidden_size=hidden_size,
                         intermediate_size=intermediate_size)
        quant_config = VllmCompressedTensorsConfig(
            target_scheme_map={
                'Linear': {
                    'weights':
                    QuantizationArgs(num_bits=8,
                                     type='float',
                                     symmetric=True,
                                     group_size=None,
                                     strategy='channel',
                                     block_structure=None,
                                     dynamic=False,
                                     actorder=None,
                                     observer='minmax',
                                     observer_kwargs={}),
                    'input_activations':
                    QuantizationArgs(num_bits=8,
                                     type='float',
                                     symmetric=True,
                                     group_size=None,
                                     strategy='token',
                                     block_structure=None,
                                     dynamic=True,
                                     actorder=None,
                                     observer=None,
                                     observer_kwargs={}),
                    'format':
                    None
                }
            },
            ignore=[],
            quant_format='compressed-tensors',
            sparsity_scheme_map={},
            sparsity_ignore_list=[],
        )
        moe = FusedMoEConfig(
            num_experts=num_experts,
            experts_per_token=topk,
            hidden_dim=hidden_size,
            num_local_experts=num_experts,
            moe_parallel_config=FusedMoEParallelConfig(
                tp_size=1,
                dp_size=1,
                ep_size=1,
                tp_rank=0,
                dp_rank=0,
                ep_rank=0,
                use_ep=use_ep,
                all2all_backend='',
            ),
            in_dtype=torch.bfloat16,
        )
        method = VllmCompressedTensorsW8A8Fp8MoEMethod(quant_config, moe, mesh)
        method.create_weights(layer,
                              num_experts,
                              hidden_size,
                              intermediate_size,
                              params_dtype=torch.float8_e4m3fn)
        method.process_weights_after_loading(layer)

    seqlen = num_tokens
    with torchax.default_env():
        x = torch.ones((seqlen, hidden_size), dtype=torch.bfloat16).to('jax')
        router_logits = torch.randn((seqlen, num_experts),
                                    dtype=torch.bfloat16).to('jax')
        result = method.apply(layer,
                              x,
                              router_logits,
                              top_k=topk,
                              renormalize=True)

        result_reference = _ref_math_in_bf16(
            layer.w13_weight.to(torch.bfloat16) * layer.w13_weight_scale,
            layer.w2_weight.to(torch.bfloat16) * layer.w2_weight_scale,
            layer.w3_weight.to(torch.bfloat16) * layer.w3_weight_scale, x,
            router_logits, topk)

        assert jnp.allclose(result.jax(), result_reference.jax())
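
For readers less familiar with the routing math in _ref_math_in_bf16, here is a minimal CPU-only sketch of the same softmax, top-k, renormalize, and weighted-sum steps on toy shapes; the function and tensor names are illustrative and not part of tpu-inference or vLLM.

import torch
import torch.nn.functional as F


def route_top_k(expert_outs: torch.Tensor, router_logits: torch.Tensor,
                top_k: int) -> torch.Tensor:
    # expert_outs: [tokens, experts, hidden]; router_logits: [tokens, experts]
    weights = F.softmax(router_logits, dim=-1)
    weights, indices = torch.topk(weights, top_k, dim=-1)        # [tokens, top_k]
    weights = weights / weights.sum(dim=-1, keepdim=True)        # renormalize over kept experts
    token_idx = torch.arange(expert_outs.shape[0]).unsqueeze(1)  # [tokens, 1]
    selected = expert_outs[token_idx, indices]                   # [tokens, top_k, hidden]
    return torch.einsum("tkh,tk->th", selected, weights)


out = route_top_k(torch.randn(4, 8, 16), torch.randn(4, 8), top_k=2)
print(out.shape)  # torch.Size([4, 16])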