tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.2rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of tpu-inference might be problematic.

Files changed (250)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +89 -26
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +46 -17
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +44 -17
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/METADATA +7 -9
  240. tpu_inference-0.13.2rc3.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/top_level.txt +0 -0
tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py (new file)
@@ -0,0 +1,418 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import tempfile
+ from typing import Optional
+
+ import jax
+ import pytest
+ import torch
+ import torchax
+ from compressed_tensors.quantization import QuantizationStrategy
+ from jax.sharding import PartitionSpec
+ from torchax.interop import torch_view
+ from torchax.ops.mappings import j2t, t2j
+ from vllm.config import set_current_vllm_config
+ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                              init_distributed_environment)
+ from vllm.engine.arg_utils import EngineArgs
+ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                                LinearBase,
+                                                MergedColumnParallelLinear,
+                                                QKVParallelLinear,
+                                                RowParallelLinear)
+ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import \
+     CompressedTensorsLinearMethod
+ from vllm.model_executor.model_loader import get_model as vllm_get_model
+
+ from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
+ from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig
+ from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
+     VllmCompressedTensorsConfig
+ from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import (
+     VllmCompressedTensorsW8A8Fp8, requantize_with_max_scale)
+
+ from . import utils as test_utils
+
+ P = PartitionSpec
+ MODELS = [
+     "RedHatAI/Llama-3.2-1B-Instruct-FP8-dynamic",
+     "RedHatAI/Llama-3.2-1B-Instruct-FP8"
+ ]
+
+
+ def ref_quantize_fp8(x: torch.Tensor,
+                      dtype: torch.dtype,
+                      per_tensor: bool = False):
+     dtype_info = torch.finfo(dtype)
+     dtype_max = float(dtype_info.max)
+     dtype_min = float(dtype_info.min)
+
+     dim = () if per_tensor else 1
+     x_abs_max = torch.amax(torch.abs(x), dim=dim, keepdim=True)
+     if per_tensor:
+         x_abs_max = torch.squeeze(x_abs_max, dim=-1)
+     x_s = x_abs_max / dtype_max
+     x_q = torch.clip(x / x_s, dtype_min, dtype_max).to(dtype)
+     return x_q, x_s.to(torch.float32)
+
+
+ def ref_w8a8_fp8_dynamic(x: torch.Tensor, w_q: torch.Tensor, w_s: torch.Tensor,
+                          b: Optional[torch.Tensor]):
+     x_q, x_s = ref_quantize_fp8(x, w_q.dtype)
+     out = torch.einsum('bd,fd->bf', x_q.to(torch.float32),
+                        w_q.to(torch.float32))
+     out = (out * x_s) * w_s.T
+     if b is not None:
+         out += b
+     return out.to(x.dtype)
+
+
+ def ref_w8a8_fp8_static(x: torch.Tensor, x_s: torch.Tensor, w_q: torch.Tensor,
+                         w_s: torch.Tensor, b: Optional[torch.Tensor]):
+     dtype_info = torch.finfo(w_q.dtype)
+     dtype_max = float(dtype_info.max)
+     dtype_min = float(dtype_info.min)
+
+     x_q = torch.clamp(x / x_s, dtype_min, dtype_max).to(w_q.dtype)
+     out = torch.einsum('bd,fd->bf', x_q.to(torch.float32),
+                        w_q.to(torch.float32))
+     out = (out * x_s) * w_s.T
+     if b is not None:
+         out += b
+     return out.to(x.dtype)
+
+
+ def return_ref_and_layer_output(layer: torch.nn.Module, batch_size: int = 16):
+     assert isinstance(layer, LinearBase)
+     scheme = layer.scheme
+     assert isinstance(scheme, VllmCompressedTensorsW8A8Fp8)
+     quant_config = scheme.jax_config
+     assert isinstance(quant_config, JaxCommonLinearConfig)
+     quant_method = layer.quant_method
+     assert isinstance(quant_method, CompressedTensorsLinearMethod)
+     per_tensor = scheme.strategy == QuantizationStrategy.TENSOR
+     is_static_input_scheme = scheme.is_static_input_scheme
+
+     input_tensor = torch.rand(
+         batch_size, layer.input_size, dtype=torch.bfloat16) / 10
+     input_tensor = input_tensor.to('cpu')
+
+     weight_scale, weight = layer.weight_scale, layer.weight
+     input_scale = getattr(layer, 'input_scale', None)
+     # For per_tensor with merged layers, vLLM requantizes them so all merged
+     # layers share the same scale values.
+     if per_tensor:
+         weight_scale, weight = requantize_with_max_scale(
+             layer.weight, layer.weight_scale, quant_config.output_sizes)
+         if input_scale is not None:
+             input_scale = input_scale.max()
+
+     # Run reference implementation
+     if is_static_input_scheme:
+         ref_output = ref_w8a8_fp8_static(
+             input_tensor,
+             input_scale,
+             weight,
+             weight_scale,
+             layer.bias,
+         )
+     else:
+         ref_output = ref_w8a8_fp8_dynamic(
+             input_tensor,
+             weight,
+             weight_scale,
+             layer.bias,
+         )
+
+     # Run torchax/jax function
+     with torchax.default_env():
+         quant_method.process_weights_after_loading(layer)
+
+         jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
+         layer_output = layer(jax_input_tensor)
+         layer_output = j2t(layer_output.to(torch.float32)).to(torch.bfloat16)
+
+     return ref_output, layer_output
+
+
+ def initialize_layer_weights(layer: torch.nn.Module):
+     assert isinstance(layer, LinearBase)
+     scheme = layer.scheme
+     assert isinstance(scheme, VllmCompressedTensorsW8A8Fp8)
+     quant_config = scheme.jax_config
+     assert isinstance(quant_config, JaxCommonLinearConfig)
+     per_tensor = scheme.strategy == QuantizationStrategy.TENSOR
+
+     weight_list = []
+     weight_scale_list = []
+     for output_size in quant_config.output_sizes:
+         weight = torch.rand(
+             (output_size, layer.input_size), dtype=torch.bfloat16) / 10
+         weight_, weight_scale_ = ref_quantize_fp8(weight, torch.float8_e4m3fn,
+                                                   per_tensor)
+         weight_list.append(weight_)
+         weight_scale_list.append(weight_scale_)
+
+     weight = torch.concatenate(weight_list)
+     weight_scale = torch.concatenate(weight_scale_list)
+
+     assert layer.weight.data.shape == weight.shape
+     assert layer.weight_scale.data.shape == weight_scale.shape
+
+     layer.weight.data = weight
+     layer.weight_scale.data = weight_scale
+
+     if layer.bias is not None:
+         layer.bias.data = torch.rand_like(layer.bias.data)
+
+
+ @pytest.fixture(autouse=True)
+ def setup_environment():
+     # This is a fake config used for init dist env.
+     # RowParallelLinear needs dist env to be initialized.
+     engine_args = EngineArgs(
+         model=MODELS[0],
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+
+     vllm_config = engine_args.create_engine_config()
+
+     with set_current_vllm_config(vllm_config):
+         temp_file = tempfile.mkstemp()[1]
+         init_distributed_environment(
+             1,
+             0,
+             local_rank=0,
+             distributed_init_method=f"file://{temp_file}",
+             backend="gloo")
+         ensure_model_parallel_initialized(1, 1)
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("mesh", [
+     test_utils.get_spmd_mesh(1),
+     test_utils.get_spmd_mesh(jax.local_device_count())
+ ])
+ def test_quant_override(model, mesh):
+
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.model_config.dtype = torch.bfloat16
+
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     assert isinstance(quant_config, VllmCompressedTensorsConfig)
+     assert quant_config.vllm_config == vllm_config
+     assert quant_config.mesh == mesh
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("mesh", [
+     test_utils.get_spmd_mesh(1),
+     test_utils.get_spmd_mesh(jax.local_device_count())
+ ])
+ def test_loading_model(model, mesh):
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.model_config.dtype = torch.bfloat16
+     vllm_config.quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     vllm_config.device_config.device = "cpu"
+
+     vllm_model = vllm_get_model(vllm_config=vllm_config)
+     layers = test_utils.find_all_layer_type(vllm_model, LinearBase)
+     for layer in layers:
+         assert isinstance(layer.quant_config, VllmCompressedTensorsConfig)
+         assert isinstance(layer.quant_method, CompressedTensorsLinearMethod)
+         assert isinstance(layer.scheme, VllmCompressedTensorsW8A8Fp8)
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("bias", [False, True])
+ @pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
+ @pytest.mark.parametrize("enable_sp", [False, True])
+ @pytest.mark.parametrize("enable_attn_dp", [False, True])
+ def test_row_parallel_linear(model, bias, num_devices, enable_sp,
+                              enable_attn_dp):
+     # Skip if enable_attn_dp is True but we don't have enough devices
+     if enable_attn_dp and num_devices < 2:
+         pytest.skip("enable_attn_dp requires at least 2 devices")
+
+     mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+     dtype = torch.bfloat16
+
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+     vllm_config.model_config.dtype = dtype
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         linear_layer = RowParallelLinear(
+             input_size=4096,
+             output_size=8192,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+             quant_config=quant_config,
+         )
+
+     initialize_layer_weights(linear_layer)
+     ref_output, layer_output = return_ref_and_layer_output(linear_layer)
+     torch.testing.assert_close(ref_output, layer_output)
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("bias", [False, True])
+ @pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
+ @pytest.mark.parametrize("enable_sp", [False, True])
+ @pytest.mark.parametrize("enable_attn_dp", [False, True])
+ def test_column_parallel_linear(model, bias, num_devices, enable_sp,
+                                 enable_attn_dp):
+     # Skip if enable_attn_dp is True but we don't have enough devices
+     if enable_attn_dp and num_devices < 2:
+         pytest.skip("enable_attn_dp requires at least 2 devices")
+
+     mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+     dtype = torch.bfloat16
+
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+     # Call tpu_inference code
+     vllm_config.model_config.dtype = torch.bfloat16
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         linear_layer = ColumnParallelLinear(
+             input_size=4096,
+             output_size=8192,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+             quant_config=quant_config,
+         )
+
+     initialize_layer_weights(linear_layer)
+     ref_output, layer_output = return_ref_and_layer_output(linear_layer)
+     torch.testing.assert_close(ref_output, layer_output)
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("bias", [False, True])
+ @pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
+ @pytest.mark.parametrize("enable_sp", [False, True])
+ @pytest.mark.parametrize("fuse_matmuls", [False, True])
+ @pytest.mark.parametrize("enable_attn_dp", [False, True])
+ def test_qkv_parallel_linear(model, bias, num_devices, enable_sp, fuse_matmuls,
+                              enable_attn_dp):
+     # Skip if enable_attn_dp is True but we don't have enough devices
+     if enable_attn_dp and num_devices < 2:
+         pytest.skip("enable_attn_dp requires at least 2 devices")
+
+     mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+     dtype = torch.bfloat16
+
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+     # Call tpu_inference code
+     vllm_config.model_config.dtype = torch.bfloat16
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         linear_layer = QKVParallelLinear(
+             hidden_size=4096,
+             head_size=128,
+             total_num_heads=32,
+             total_num_kv_heads=8,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+             quant_config=quant_config,
+         )
+         linear_layer.quant_method.fuse_matmuls = fuse_matmuls
+
+     initialize_layer_weights(linear_layer)
+     ref_output, layer_output = return_ref_and_layer_output(linear_layer)
+     torch.testing.assert_close(ref_output, layer_output)
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("bias", [False, True])
+ @pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
+ @pytest.mark.parametrize("fuse_matmuls", [False, True])
+ @pytest.mark.parametrize("enable_sp", [False, True])
+ @pytest.mark.parametrize("enable_attn_dp", [False, True])
+ def test_merged_column_parallel_linear(model, bias, num_devices, fuse_matmuls,
+                                        enable_sp, enable_attn_dp):
+     # Skip if enable_attn_dp is True but we don't have enough devices
+     if enable_attn_dp and num_devices < 2:
+         pytest.skip("enable_attn_dp requires at least 2 devices")
+
+     mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+     dtype = torch.bfloat16
+
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+     # Call tpu_inference code
+     vllm_config.model_config.dtype = torch.bfloat16
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         linear_layer = MergedColumnParallelLinear(
+             input_size=4096,
+             output_sizes=[14336] * 2,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+             quant_config=quant_config,
+         )
+         linear_layer.quant_method.fuse_matmuls = fuse_matmuls
+
+     initialize_layer_weights(linear_layer)
+     ref_output, layer_output = return_ref_and_layer_output(linear_layer)
+     torch.testing.assert_close(ref_output, layer_output)