tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tests/layers/vllm/test_compressed_tensors_w8a8_int8.py
@@ -0,0 +1,443 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+from typing import Optional
+
+import jax
+import pytest
+import torch
+import torchax
+from jax.sharding import NamedSharding, PartitionSpec
+from torchax.interop import torch_view
+from torchax.ops.mappings import j2t, t2j
+from vllm.config import set_current_vllm_config
+from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                             init_distributed_environment)
+from vllm.engine.arg_utils import EngineArgs
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               LinearBase,
+                                               MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import \
+    CompressedTensorsLinearMethod
+from vllm.model_executor.model_loader import get_model as vllm_get_model
+
+from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
+from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
+    VllmCompressedTensorsConfig
+from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
+    VllmCompressedTensorsW8A8Int8
+
+from . import utils as test_utils
+
+P = PartitionSpec
+MODELS = ["RedHatAI/Qwen2.5-1.5B-quantized.w8a8"]
+
+
+def ref_quantize_int8(x: torch.Tensor):
+    x_abs_max = torch.amax(torch.abs(x), dim=1, keepdim=True)
+    x_s = x_abs_max / 127
+    x_q = torch.round(x / x_s).to(torch.int8)
+    return x_q, x_s.to(torch.float32)
+
+
+def ref_w8a8_int8(x: torch.Tensor, w_q: torch.Tensor, w_s: torch.Tensor,
+                  b: Optional[torch.Tensor]):
+    x_q, x_s = ref_quantize_int8(x)
+    out = torch.einsum('bd,fd->bf', x_q.to(torch.float32),
+                       w_q.to(torch.float32))
+    out = (out * x_s) * w_s.T
+    if b is not None:
+        out += b
+    return out.to(x.dtype)
+
+
+@pytest.fixture(autouse=True)
+def setup_environment():
+    # This is a fake config used for init dist env.
+    # RowParallelLinear needs dist env to be initialized.
+    engine_args = EngineArgs(
+        model=MODELS[0],
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+
+    vllm_config = engine_args.create_engine_config()
+
+    with set_current_vllm_config(vllm_config):
+        temp_file = tempfile.mkstemp()[1]
+        init_distributed_environment(
+            1,
+            0,
+            local_rank=0,
+            distributed_init_method=f"file://{temp_file}",
+            backend="gloo")
+        ensure_model_parallel_initialized(1, 1)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("mesh", [
+    test_utils.get_spmd_mesh(1),
+    test_utils.get_spmd_mesh(jax.local_device_count())
+])
+def test_quant_override(model, mesh):
+
+    engine_args = EngineArgs(
+        model=model,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.model_config.dtype = torch.bfloat16
+
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    assert isinstance(quant_config, VllmCompressedTensorsConfig)
+    assert quant_config.vllm_config == vllm_config
+    assert quant_config.mesh == mesh
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("mesh", [
+    test_utils.get_spmd_mesh(1),
+    test_utils.get_spmd_mesh(jax.local_device_count())
+])
+def test_loading_model(model, mesh):
+    engine_args = EngineArgs(
+        model=model,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.model_config.dtype = torch.bfloat16
+    vllm_config.quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    vllm_config.device_config.device = "cpu"
+
+    vllm_model = vllm_get_model(vllm_config=vllm_config)
+    layers = test_utils.find_all_layer_type(vllm_model, LinearBase)
+    for layer in layers:
+        assert isinstance(layer.quant_config, VllmCompressedTensorsConfig)
+        assert isinstance(layer.quant_method, CompressedTensorsLinearMethod)
+        assert isinstance(layer.scheme, VllmCompressedTensorsW8A8Int8)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("bias", [False, True])
+@pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
+@pytest.mark.parametrize("enable_sp", [False, True])
+@pytest.mark.parametrize("enable_attn_dp", [False, True])
+def test_row_parallel_linear(model, bias, num_devices, enable_sp,
+                             enable_attn_dp):
+    # Skip if enable_attn_dp is True but we don't have enough devices
+    if enable_attn_dp and num_devices < 2:
+        pytest.skip("enable_attn_dp requires at least 2 devices")
+
+    mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+
+    dtype = torch.bfloat16
+
+    engine_args = EngineArgs(
+        model=model,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+    # Call tpu_inference code
+    vllm_config.model_config.dtype = dtype
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    with set_current_vllm_config(vllm_config):
+        jax_row_linear = RowParallelLinear(
+            input_size=4096,
+            output_size=8192,
+            bias=bias,
+            params_dtype=dtype,
+            return_bias=False,
+            quant_config=quant_config,
+        )
+
+    weight_data_float = torch.rand(
+        (jax_row_linear.output_size, jax_row_linear.input_size),
+        dtype=dtype) / 10
+    weight_data, weight_scale_data = ref_quantize_int8(weight_data_float)
+    if bias:
+        bias_data = torch.rand_like(jax_row_linear.bias.data)
+
+    jax_row_linear.weight.data = weight_data
+    jax_row_linear.weight_scale.data = weight_scale_data
+    if bias:
+        jax_row_linear.bias.data = bias_data
+
+    input_tensor = torch.rand(10, jax_row_linear.input_size, dtype=dtype) / 10
+    input_tensor = input_tensor.to('cpu')
+
+    jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
+    jax_input_tensor.apply_jax_(jax.device_put,
+                                NamedSharding(mesh, P(None, None)))
+    with torchax.default_env():
+        assert isinstance(jax_row_linear.quant_method,
+                          CompressedTensorsLinearMethod)
+        assert isinstance(jax_row_linear.scheme, VllmCompressedTensorsW8A8Int8)
+        jax_row_linear.quant_method.process_weights_after_loading(
+            jax_row_linear)
+        jax_output = jax_row_linear(jax_input_tensor)
+        jax_output = j2t(jax_output.to(torch.float32)).to(dtype)
+
+    # Call reference w8a8 int8
+    output = ref_w8a8_int8(
+        input_tensor,
+        weight_data,
+        weight_scale_data,
+        bias_data if bias else None,
+    )
+
+    torch.testing.assert_close(output, jax_output)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("bias", [False, True])
+@pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
+@pytest.mark.parametrize("enable_sp", [False, True])
+@pytest.mark.parametrize("enable_attn_dp", [False, True])
+def test_column_parallel_linear(model, bias, num_devices, enable_sp,
+                                enable_attn_dp):
+    # Skip if enable_attn_dp is True but we don't have enough devices
+    if enable_attn_dp and num_devices < 2:
+        pytest.skip("enable_attn_dp requires at least 2 devices")
+
+    mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+    dtype = torch.bfloat16
+
+    engine_args = EngineArgs(
+        model=model,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+    # Call tpu_inference code
+    vllm_config.model_config.dtype = torch.bfloat16
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    with set_current_vllm_config(vllm_config):
+        jax_column_linear = ColumnParallelLinear(
+            input_size=4096,
+            output_size=8192,
+            bias=bias,
+            params_dtype=dtype,
+            return_bias=False,
+            quant_config=quant_config,
+        )
+
+    weight_data_float = torch.rand(
+        (jax_column_linear.output_size, jax_column_linear.input_size),
+        dtype=dtype) / 10
+    weight_data, weight_scale_data = ref_quantize_int8(weight_data_float)
+    if bias:
+        bias_data = torch.rand_like(jax_column_linear.bias.data)
+
+    jax_column_linear.weight.data = weight_data
+    jax_column_linear.weight_scale.data = weight_scale_data
+    if bias:
+        jax_column_linear.bias.data = bias_data
+
+    input_tensor = torch.rand(10, jax_column_linear.input_size,
+                              dtype=dtype) / 10
+    input_tensor = input_tensor.to('cpu')
+
+    jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
+    jax_input_tensor.apply_jax_(jax.device_put,
+                                NamedSharding(mesh, P(None, None)))
+    with torchax.default_env():
+        assert isinstance(jax_column_linear.quant_method,
+                          CompressedTensorsLinearMethod)
+        assert isinstance(jax_column_linear.scheme,
+                          VllmCompressedTensorsW8A8Int8)
+        jax_column_linear.quant_method.process_weights_after_loading(
+            jax_column_linear)
+        jax_output = jax_column_linear(jax_input_tensor)
+        jax_output = j2t(jax_output.to(torch.float32)).to(dtype)
+
+    # Call reference w8a8 int8
+    output = ref_w8a8_int8(
+        input_tensor,
+        weight_data,
+        weight_scale_data,
+        bias_data if bias else None,
+    )
+
+    torch.testing.assert_close(output, jax_output)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("bias", [False, True])
+@pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
+@pytest.mark.parametrize("enable_sp", [False, True])
+@pytest.mark.parametrize("fuse_matmuls", [False, True])
+@pytest.mark.parametrize("enable_attn_dp", [False, True])
+def test_qkv_parallel_linear(model, bias, num_devices, enable_sp, fuse_matmuls,
+                             enable_attn_dp):
+    # Skip if enable_attn_dp is True but we don't have enough devices
+    if enable_attn_dp and num_devices < 2:
+        pytest.skip("enable_attn_dp requires at least 2 devices")
+
+    mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+    dtype = torch.bfloat16
+
+    engine_args = EngineArgs(
+        model=model,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+    # Call tpu_inference code
+    vllm_config.model_config.dtype = torch.bfloat16
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    with set_current_vllm_config(vllm_config):
+        jax_qkv_linear = QKVParallelLinear(
+            hidden_size=4096,
+            head_size=128,
+            total_num_heads=32,
+            total_num_kv_heads=8,
+            bias=bias,
+            params_dtype=dtype,
+            return_bias=False,
+            quant_config=quant_config,
+        )
+    jax_qkv_linear.quant_method.fuse_matmuls = fuse_matmuls
+
+    weight_data_float = torch.rand(
+        (jax_qkv_linear.output_size, jax_qkv_linear.input_size),
+        dtype=dtype) / 10
+    weight_data, weight_scale_data = ref_quantize_int8(weight_data_float)
+    if bias:
+        bias_data = torch.rand_like(jax_qkv_linear.bias.data)
+
+    jax_qkv_linear.weight.data = weight_data
+    jax_qkv_linear.weight_scale.data = weight_scale_data
+    if bias:
+        jax_qkv_linear.bias.data = bias_data
+
+    input_tensor = torch.rand(10, jax_qkv_linear.input_size, dtype=dtype) / 10
+    input_tensor = input_tensor.to('cpu')
+
+    jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
+    jax_input_tensor.apply_jax_(jax.device_put,
+                                NamedSharding(mesh, P(None, None)))
+    with torchax.default_env():
+        assert isinstance(jax_qkv_linear.quant_method,
+                          CompressedTensorsLinearMethod)
+        assert isinstance(jax_qkv_linear.scheme, VllmCompressedTensorsW8A8Int8)
+        jax_qkv_linear.quant_method.process_weights_after_loading(
+            jax_qkv_linear)
+        jax_output = jax_qkv_linear(jax_input_tensor)
+        jax_output = j2t(jax_output.to(torch.float32)).to(dtype)
+
+    # Call reference w8a8 int8
+    output = ref_w8a8_int8(
+        input_tensor,
+        weight_data,
+        weight_scale_data,
+        bias_data if bias else None,
+    )
+
+    torch.testing.assert_close(output, jax_output)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("bias", [False, True])
+@pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
+@pytest.mark.parametrize("fuse_matmuls", [False, True])
+@pytest.mark.parametrize("enable_sp", [False, True])
+@pytest.mark.parametrize("enable_attn_dp", [False, True])
+def test_merged_column_parallel_linear(model, bias, num_devices, fuse_matmuls,
+                                       enable_sp, enable_attn_dp):
+    # Skip if enable_attn_dp is True but we don't have enough devices
+    if enable_attn_dp and num_devices < 2:
+        pytest.skip("enable_attn_dp requires at least 2 devices")
+
+    mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+    dtype = torch.bfloat16

+    engine_args = EngineArgs(
+        model=model,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+    # Call tpu_inference code
+    vllm_config.model_config.dtype = torch.bfloat16
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    with set_current_vllm_config(vllm_config):
+        jax_merged_column_linear = MergedColumnParallelLinear(
+            input_size=4096,
+            output_sizes=[14336] * 2,
+            bias=bias,
+            params_dtype=dtype,
+            return_bias=False,
+            quant_config=quant_config,
+        )
+    jax_merged_column_linear.quant_method.fuse_matmuls = fuse_matmuls
+
+    weight_data_float = torch.rand((jax_merged_column_linear.output_size,
+                                    jax_merged_column_linear.input_size),
+                                   dtype=dtype) / 10
+    weight_data, weight_scale_data = ref_quantize_int8(weight_data_float)
+    if bias:
+        bias_data = torch.rand_like(jax_merged_column_linear.bias.data)
+
+    jax_merged_column_linear.weight.data = weight_data
+    jax_merged_column_linear.weight_scale.data = weight_scale_data
+    if bias:
+        jax_merged_column_linear.bias.data = bias_data
+
+    input_tensor = torch.rand(
+        10, jax_merged_column_linear.input_size, dtype=dtype) / 10
+    input_tensor = input_tensor.to('cpu')
+
+    jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
+    jax_input_tensor.apply_jax_(jax.device_put,
+                                NamedSharding(mesh, P(None, None)))
+    with torchax.default_env():
+        assert isinstance(jax_merged_column_linear.quant_method,
+                          CompressedTensorsLinearMethod)
+        assert isinstance(jax_merged_column_linear.scheme,
+                          VllmCompressedTensorsW8A8Int8)
+        jax_merged_column_linear.quant_method.process_weights_after_loading(
+            jax_merged_column_linear)
+        jax_output = jax_merged_column_linear(jax_input_tensor)
+        jax_output = j2t(jax_output.to(torch.float32)).to(dtype)
+
+    # Call reference w8a8 int8
+    output = ref_w8a8_int8(
+        input_tensor,
+        weight_data,
+        weight_scale_data,
+        bias_data if bias else None,
+    )
+
+    torch.testing.assert_close(output, jax_output)
tests/layers/vllm/test_fp8.py
@@ -0,0 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+pytest.skip("FP8 implementation not complete yet", allow_module_level=True)