tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py (new file)
@@ -0,0 +1,441 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+from typing import Optional
+
+import jax
+import jax.numpy as jnp
+import pytest
+import torch
+import torchax
+from compressed_tensors.quantization import QuantizationStrategy
+from jax.sharding import PartitionSpec
+from torchax.interop import torch_view
+from torchax.ops.mappings import j2t, t2j
+from vllm.config import set_current_vllm_config
+from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                             init_distributed_environment)
+from vllm.engine.arg_utils import EngineArgs
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               LinearBase,
+                                               MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import \
+    CompressedTensorsLinearMethod
+from vllm.model_executor.model_loader import get_model as vllm_get_model
+
+from tpu_inference.layers.common.quantization import (dequantize_tensor,
+                                                       quantize_tensor)
+from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
+from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
+    VllmCompressedTensorsConfig
+from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
+    VllmCompressedTensorsW8A8Fp8
+from tpu_inference.layers.vllm.quantization.configs import \
+    VllmQuantLinearConfig
+
+from . import utils as test_utils
+
+P = PartitionSpec
+MODELS = [
+    "RedHatAI/Llama-3.2-1B-Instruct-FP8-dynamic",
+    "RedHatAI/Llama-3.2-1B-Instruct-FP8"
+]
+
+
+def ref_quantize_fp8(x: torch.Tensor,
+                     dtype: torch.dtype,
+                     per_tensor: bool = False):
+    dtype_info = torch.finfo(dtype)
+    dtype_max = float(dtype_info.max)
+    dtype_min = float(dtype_info.min)
+
+    dim = () if per_tensor else 1
+    x_abs_max = torch.amax(torch.abs(x), dim=dim, keepdim=True)
+    if per_tensor:
+        x_abs_max = torch.squeeze(x_abs_max, dim=-1)
+    x_s = x_abs_max / dtype_max
+    x_q = torch.clip(x / x_s, dtype_min, dtype_max).to(dtype)
+    return x_q, x_s.to(torch.float32)
+
+
+def ref_w8a8_fp8_dynamic(x: torch.Tensor, w_q: torch.Tensor, w_s: torch.Tensor,
+                         b: Optional[torch.Tensor]):
+    x_q, x_s = ref_quantize_fp8(x, w_q.dtype)
+    out = torch.einsum('bd,fd->bf', x_q.to(torch.float32),
+                       w_q.to(torch.float32))
+    out = (out * x_s) * w_s.T
+    if b is not None:
+        out += b
+    return out.to(x.dtype)
+
+
+def ref_w8a8_fp8_static(x: torch.Tensor, x_s: torch.Tensor, w_q: torch.Tensor,
+                        w_s: torch.Tensor, b: Optional[torch.Tensor]):
+    dtype_info = torch.finfo(w_q.dtype)
+    dtype_max = float(dtype_info.max)
+    dtype_min = float(dtype_info.min)
+
+    x_q = torch.clamp(x / x_s, dtype_min, dtype_max).to(w_q.dtype)
+    out = torch.einsum('bd,fd->bf', x_q.to(torch.float32),
+                       w_q.to(torch.float32))
+    out = (out * x_s) * w_s.T
+    if b is not None:
+        out += b
+    return out.to(x.dtype)
+
+
+def return_ref_and_layer_output(layer: torch.nn.Module, batch_size: int = 16):
+    assert isinstance(layer, LinearBase)
+    scheme = layer.scheme
+    assert isinstance(scheme, VllmCompressedTensorsW8A8Fp8)
+    quant_config = scheme.linear_config
+    assert isinstance(quant_config, VllmQuantLinearConfig)
+    quant_method = layer.quant_method
+    assert isinstance(quant_method, CompressedTensorsLinearMethod)
+    per_tensor = scheme.strategy == QuantizationStrategy.TENSOR
+    is_static_input_scheme = scheme.is_static_input_scheme
+
+    input_tensor = torch.rand(
+        batch_size, layer.input_size, dtype=torch.bfloat16) / 10
+    input_tensor = input_tensor.to('cpu')
+
+    weight_scale, weight = layer.weight_scale, layer.weight
+    input_scale = getattr(layer, 'input_scale', None)
+    # For per_tensor with merged layers, vLLM requantizes them so all merged
+    # layers share the same scale values.
+    if per_tensor:
+        dtype = weight.dtype
+
+        weight = t2j(weight)
+        weight_scale = t2j(weight_scale)
+        weights = []
+        start = 0
+        # Multiple weights may have been concatenated. Loop through
+        # each weight and perform dequantization.
+        for i, output_size in enumerate(quant_config.output_sizes):
+            end = start + output_size
+            weights.append(
+                dequantize_tensor(weight[start:end], weight_scale[i]))
+            start = end
+        weight = jnp.concat(weights, axis=0)
+        weight, weight_scale = quantize_tensor(
+            jnp.float8_e4m3fn,
+            weight,
+            None,
+        )
+        weight = j2t(weight.astype(jnp.float32)).to(dtype)
+        weight_scale = j2t(weight_scale)
+        if input_scale is not None:
+            input_scale = input_scale.max()
+
+    # Run reference implementation
+    if is_static_input_scheme:
+        ref_output = ref_w8a8_fp8_static(
+            input_tensor,
+            input_scale,
+            weight,
+            weight_scale,
+            layer.bias,
+        )
+    else:
+        ref_output = ref_w8a8_fp8_dynamic(
+            input_tensor,
+            weight,
+            weight_scale,
+            layer.bias,
+        )
+
+    # Run torchax/jax function
+    with torchax.default_env():
+        quant_method.process_weights_after_loading(layer)
+
+        jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
+        layer_output = layer(jax_input_tensor)
+        layer_output = j2t(layer_output.to(torch.float32)).to(torch.bfloat16)
+
+    return ref_output, layer_output
+
+
+def initialize_layer_weights(layer: torch.nn.Module):
+    assert isinstance(layer, LinearBase)
+    scheme = layer.scheme
+    assert isinstance(scheme, VllmCompressedTensorsW8A8Fp8)
+    quant_config = scheme.linear_config
+    assert isinstance(quant_config, VllmQuantLinearConfig)
+    per_tensor = scheme.strategy == QuantizationStrategy.TENSOR
+
+    weight_list = []
+    weight_scale_list = []
+    for output_size in quant_config.output_sizes:
+        weight = torch.rand(
+            (output_size, layer.input_size), dtype=torch.bfloat16) / 10
+        weight_, weight_scale_ = ref_quantize_fp8(weight, torch.float8_e4m3fn,
+                                                  per_tensor)
+        weight_list.append(weight_)
+        weight_scale_list.append(weight_scale_)
+
+    weight = torch.concatenate(weight_list)
+    weight_scale = torch.concatenate(weight_scale_list)
+
+    assert layer.weight.data.shape == weight.shape
+    assert layer.weight_scale.data.shape == weight_scale.shape
+
+    layer.weight.data = weight
+    layer.weight_scale.data = weight_scale
+
+    if layer.bias is not None:
+        layer.bias.data = torch.rand_like(layer.bias.data)
+
+
+@pytest.fixture(autouse=True)
+def setup_environment():
+    # This is a fake config used for init dist env.
+    # RowParallelLinear needs dist env to be initialized.
+    engine_args = EngineArgs(
+        model=MODELS[0],
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+
+    vllm_config = engine_args.create_engine_config()
+
+    with set_current_vllm_config(vllm_config):
+        temp_file = tempfile.mkstemp()[1]
+        init_distributed_environment(
+            1,
+            0,
+            local_rank=0,
+            distributed_init_method=f"file://{temp_file}",
+            backend="gloo")
+        ensure_model_parallel_initialized(1, 1)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("mesh", [
+    test_utils.get_spmd_mesh(1),
+    test_utils.get_spmd_mesh(jax.local_device_count())
+])
+def test_quant_override(model, mesh):
+
+    engine_args = EngineArgs(
+        model=model,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.model_config.dtype = torch.bfloat16
+
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    assert isinstance(quant_config, VllmCompressedTensorsConfig)
+    assert quant_config.vllm_config == vllm_config
+    assert quant_config.mesh == mesh
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("mesh", [
+    test_utils.get_spmd_mesh(1),
+    test_utils.get_spmd_mesh(jax.local_device_count())
+])
+def test_loading_model(model, mesh):
+    engine_args = EngineArgs(
+        model=model,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.model_config.dtype = torch.bfloat16
+    vllm_config.quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    vllm_config.device_config.device = "cpu"
+
+    vllm_model = vllm_get_model(vllm_config=vllm_config)
+    layers = test_utils.find_all_layer_type(vllm_model, LinearBase)
+    for layer in layers:
+        assert isinstance(layer.quant_config, VllmCompressedTensorsConfig)
+        assert isinstance(layer.quant_method, CompressedTensorsLinearMethod)
+        assert isinstance(layer.scheme, VllmCompressedTensorsW8A8Fp8)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("bias", [False, True])
+@pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
+@pytest.mark.parametrize("enable_sp", [False, True])
+@pytest.mark.parametrize("enable_attn_dp", [False, True])
+def test_row_parallel_linear(model, bias, num_devices, enable_sp,
+                             enable_attn_dp):
+    # Skip if enable_attn_dp is True but we don't have enough devices
+    if enable_attn_dp and num_devices < 2:
+        pytest.skip("enable_attn_dp requires at least 2 devices")
+
+    mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+    dtype = torch.bfloat16
+
+    engine_args = EngineArgs(
+        model=model,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+    vllm_config.model_config.dtype = dtype
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    with set_current_vllm_config(vllm_config):
+        linear_layer = RowParallelLinear(
+            input_size=4096,
+            output_size=8192,
+            bias=bias,
+            params_dtype=dtype,
+            return_bias=False,
+            quant_config=quant_config,
+        )
+
+    initialize_layer_weights(linear_layer)
+    ref_output, layer_output = return_ref_and_layer_output(linear_layer)
+    torch.testing.assert_close(ref_output, layer_output)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("bias", [False, True])
+@pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
+@pytest.mark.parametrize("enable_sp", [False, True])
+@pytest.mark.parametrize("enable_attn_dp", [False, True])
+def test_column_parallel_linear(model, bias, num_devices, enable_sp,
+                                enable_attn_dp):
+    # Skip if enable_attn_dp is True but we don't have enough devices
+    if enable_attn_dp and num_devices < 2:
+        pytest.skip("enable_attn_dp requires at least 2 devices")
+
+    mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+    dtype = torch.bfloat16
+
+    engine_args = EngineArgs(
+        model=model,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+    # Call tpu_inference code
+    vllm_config.model_config.dtype = torch.bfloat16
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    with set_current_vllm_config(vllm_config):
+        linear_layer = ColumnParallelLinear(
+            input_size=4096,
+            output_size=8192,
+            bias=bias,
+            params_dtype=dtype,
+            return_bias=False,
+            quant_config=quant_config,
+        )
+
+    initialize_layer_weights(linear_layer)
+    ref_output, layer_output = return_ref_and_layer_output(linear_layer)
+    torch.testing.assert_close(ref_output, layer_output)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("bias", [False, True])
+@pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
+@pytest.mark.parametrize("enable_sp", [False, True])
+@pytest.mark.parametrize("fuse_matmuls", [False, True])
+@pytest.mark.parametrize("enable_attn_dp", [False, True])
+def test_qkv_parallel_linear(model, bias, num_devices, enable_sp, fuse_matmuls,
+                             enable_attn_dp):
+    # Skip if enable_attn_dp is True but we don't have enough devices
+    if enable_attn_dp and num_devices < 2:
+        pytest.skip("enable_attn_dp requires at least 2 devices")
+
+    mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+    dtype = torch.bfloat16
+
+    engine_args = EngineArgs(
+        model=model,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+    # Call tpu_inference code
+    vllm_config.model_config.dtype = torch.bfloat16
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    with set_current_vllm_config(vllm_config):
+        linear_layer = QKVParallelLinear(
+            hidden_size=4096,
+            head_size=128,
+            total_num_heads=32,
+            total_num_kv_heads=8,
+            bias=bias,
+            params_dtype=dtype,
+            return_bias=False,
+            quant_config=quant_config,
+        )
+        linear_layer.quant_method.fuse_matmuls = fuse_matmuls
+
+    initialize_layer_weights(linear_layer)
+    ref_output, layer_output = return_ref_and_layer_output(linear_layer)
+    torch.testing.assert_close(ref_output, layer_output)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("bias", [False, True])
+@pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
+@pytest.mark.parametrize("fuse_matmuls", [False, True])
+@pytest.mark.parametrize("enable_sp", [False, True])
+@pytest.mark.parametrize("enable_attn_dp", [False, True])
+def test_merged_column_parallel_linear(model, bias, num_devices, fuse_matmuls,
+                                       enable_sp, enable_attn_dp):
+    # Skip if enable_attn_dp is True but we don't have enough devices
+    if enable_attn_dp and num_devices < 2:
+        pytest.skip("enable_attn_dp requires at least 2 devices")
+
+    mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+    dtype = torch.bfloat16
+
+    engine_args = EngineArgs(
+        model=model,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+    # Call tpu_inference code
+    vllm_config.model_config.dtype = torch.bfloat16
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    with set_current_vllm_config(vllm_config):
+        linear_layer = MergedColumnParallelLinear(
+            input_size=4096,
+            output_sizes=[14336] * 2,
+            bias=bias,
+            params_dtype=dtype,
+            return_bias=False,
+            quant_config=quant_config,
+        )
+        linear_layer.quant_method.fuse_matmuls = fuse_matmuls
+
+    initialize_layer_weights(linear_layer)
+    ref_output, layer_output = return_ref_and_layer_output(linear_layer)
+    torch.testing.assert_close(ref_output, layer_output)
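
For orientation, the reference math this new test checks reduces to symmetric FP8 quantization (scale = amax / float8 max, one scale per output channel) followed by a matmul that is rescaled back to the original range. The sketch below restates that round trip with plain PyTorch only; the helper names (quantize_fp8_per_channel, w8a8_fp8_dynamic_matmul) and tolerances are illustrative, not part of tpu-inference or vLLM, and it assumes a PyTorch build that provides torch.float8_e4m3fn.

import torch

def quantize_fp8_per_channel(w: torch.Tensor):
    # w: [out_features, in_features]; one symmetric scale per output channel.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = w.abs().amax(dim=1, keepdim=True) / finfo.max           # [out, 1]
    w_q = (w / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return w_q, scale.to(torch.float32)

def w8a8_fp8_dynamic_matmul(x: torch.Tensor, w_q: torch.Tensor,
                            w_s: torch.Tensor) -> torch.Tensor:
    # Dynamically quantize activations per row, matmul in float32, rescale.
    finfo = torch.finfo(torch.float8_e4m3fn)
    x_s = x.abs().amax(dim=1, keepdim=True) / finfo.max             # [batch, 1]
    x_q = (x / x_s).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    out = x_q.to(torch.float32) @ w_q.to(torch.float32).T           # [batch, out]
    return (out * x_s * w_s.T).to(x.dtype)

# The quantized matmul should track a plain bfloat16 matmul within FP8
# rounding error (loose tolerance to absorb e4m3 mantissa truncation).
x = torch.rand(16, 256, dtype=torch.bfloat16) / 10
w = torch.rand(128, 256, dtype=torch.bfloat16) / 10
w_q, w_s = quantize_fp8_per_channel(w)
approx = w8a8_fp8_dynamic_matmul(x, w_q, w_s)
exact = (x.to(torch.float32) @ w.to(torch.float32).T).to(torch.bfloat16)
torch.testing.assert_close(approx, exact, atol=3e-2, rtol=6e-2)

The static-input variant in the diff (ref_w8a8_fp8_static) follows the same shape, except the activation scale comes from the checkpoint's input_scale instead of being recomputed from the batch.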