tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of tpu-inference might be problematic.

Files changed (248)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +14 -0
  31. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  32. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
  35. tests/layers/__init__.py +13 -0
  36. tests/layers/common/__init__.py +13 -0
  37. tests/layers/common/test_attention_interface.py +156 -0
  38. tests/layers/common/test_quantization.py +149 -0
  39. tests/layers/jax/__init__.py +13 -0
  40. tests/layers/jax/attention/__init__.py +13 -0
  41. tests/layers/jax/attention/test_common_attention.py +103 -0
  42. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  43. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  44. tests/layers/jax/moe/__init__.py +13 -0
  45. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  46. tests/layers/jax/sample/__init__.py +13 -0
  47. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  48. tests/layers/jax/sample/test_sampling.py +115 -0
  49. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  50. tests/layers/jax/test_layers.py +155 -0
  51. tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
  52. tests/layers/jax/test_rope.py +93 -0
  53. tests/layers/jax/test_sharding.py +159 -0
  54. tests/layers/jax/test_transformer_block.py +152 -0
  55. tests/layers/vllm/__init__.py +13 -0
  56. tests/layers/vllm/test_attention.py +363 -0
  57. tests/layers/vllm/test_awq.py +406 -0
  58. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  59. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  61. tests/layers/vllm/test_fp8.py +17 -0
  62. tests/layers/vllm/test_mxfp4.py +320 -0
  63. tests/layers/vllm/test_unquantized.py +662 -0
  64. tests/layers/vllm/utils.py +87 -0
  65. tests/lora/__init__.py +13 -0
  66. tests/lora/conftest.py +14 -0
  67. tests/lora/test_bgmv.py +14 -0
  68. tests/lora/test_layers.py +25 -8
  69. tests/lora/test_lora.py +15 -1
  70. tests/lora/test_lora_perf.py +14 -0
  71. tests/models/__init__.py +13 -0
  72. tests/models/common/__init__.py +13 -0
  73. tests/models/common/test_model_loader.py +455 -0
  74. tests/models/jax/__init__.py +13 -0
  75. tests/models/jax/test_deepseek_v3.py +401 -0
  76. tests/models/jax/test_llama3.py +184 -0
  77. tests/models/jax/test_llama4.py +298 -0
  78. tests/models/jax/test_llama_eagle3.py +197 -0
  79. tests/models/jax/test_llama_guard_4.py +242 -0
  80. tests/models/jax/test_qwen2.py +172 -0
  81. tests/models/jax/test_qwen2_5_vl.py +605 -0
  82. tests/models/jax/test_qwen3.py +169 -0
  83. tests/models/jax/test_weight_loading.py +180 -0
  84. tests/models/jax/utils/__init__.py +13 -0
  85. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  86. tests/platforms/__init__.py +13 -0
  87. tests/platforms/test_tpu_platform.py +54 -0
  88. tests/runner/__init__.py +13 -0
  89. tests/runner/test_block_table.py +395 -0
  90. tests/runner/test_input_batch.py +226 -0
  91. tests/runner/test_kv_cache.py +220 -0
  92. tests/runner/test_kv_cache_manager.py +498 -0
  93. tests/runner/test_multimodal_manager.py +429 -0
  94. tests/runner/test_persistent_batch_manager.py +84 -0
  95. tests/runner/test_speculative_decoding_manager.py +368 -0
  96. tests/runner/test_structured_decoding_manager.py +220 -0
  97. tests/runner/test_tpu_runner.py +261 -0
  98. tests/runner/test_tpu_runner_dp.py +1099 -0
  99. tests/runner/test_tpu_runner_mesh.py +200 -0
  100. tests/runner/test_utils.py +411 -0
  101. tests/spec_decode/__init__.py +13 -0
  102. tests/spec_decode/test_eagle3.py +311 -0
  103. tests/test_base.py +14 -0
  104. tests/test_tpu_info.py +14 -0
  105. tests/test_utils.py +1 -43
  106. tests/worker/__init__.py +13 -0
  107. tests/worker/tpu_worker_test.py +414 -0
  108. tpu_inference/__init__.py +14 -0
  109. tpu_inference/core/__init__.py +13 -0
  110. tpu_inference/core/sched/__init__.py +13 -0
  111. tpu_inference/core/sched/dp_scheduler.py +372 -56
  112. tpu_inference/distributed/__init__.py +13 -0
  113. tpu_inference/distributed/jax_parallel_state.py +14 -0
  114. tpu_inference/distributed/tpu_connector.py +14 -9
  115. tpu_inference/distributed/utils.py +56 -4
  116. tpu_inference/executors/__init__.py +13 -0
  117. tpu_inference/executors/ray_distributed_executor.py +20 -3
  118. tpu_inference/experimental/__init__.py +13 -0
  119. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  120. tpu_inference/kernels/__init__.py +13 -0
  121. tpu_inference/kernels/collectives/__init__.py +13 -0
  122. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  123. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  124. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  125. tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
  126. tpu_inference/kernels/megablox/__init__.py +13 -0
  127. tpu_inference/kernels/megablox/common.py +54 -0
  128. tpu_inference/kernels/megablox/gmm.py +646 -0
  129. tpu_inference/kernels/mla/__init__.py +13 -0
  130. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  131. tpu_inference/kernels/mla/v1/kernel.py +20 -26
  132. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  133. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  134. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  135. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  136. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
  137. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
  138. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  139. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
  140. tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
  141. tpu_inference/layers/__init__.py +13 -0
  142. tpu_inference/layers/common/__init__.py +13 -0
  143. tpu_inference/layers/common/attention_interface.py +26 -19
  144. tpu_inference/layers/common/attention_metadata.py +14 -0
  145. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  146. tpu_inference/layers/common/quant_methods.py +15 -0
  147. tpu_inference/layers/common/quantization.py +282 -0
  148. tpu_inference/layers/common/sharding.py +22 -3
  149. tpu_inference/layers/common/utils.py +94 -0
  150. tpu_inference/layers/jax/__init__.py +13 -0
  151. tpu_inference/layers/jax/attention/__init__.py +13 -0
  152. tpu_inference/layers/jax/attention/attention.py +19 -6
  153. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
  154. tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
  155. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  156. tpu_inference/layers/jax/base.py +14 -0
  157. tpu_inference/layers/jax/constants.py +13 -0
  158. tpu_inference/layers/jax/layers.py +14 -0
  159. tpu_inference/layers/jax/misc.py +14 -0
  160. tpu_inference/layers/jax/moe/__init__.py +13 -0
  161. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  162. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  163. tpu_inference/layers/jax/moe/moe.py +43 -3
  164. tpu_inference/layers/jax/pp_utils.py +53 -0
  165. tpu_inference/layers/jax/rope.py +14 -0
  166. tpu_inference/layers/jax/rope_interface.py +14 -0
  167. tpu_inference/layers/jax/sample/__init__.py +13 -0
  168. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  169. tpu_inference/layers/jax/sample/sampling.py +15 -1
  170. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  171. tpu_inference/layers/jax/transformer_block.py +14 -0
  172. tpu_inference/layers/vllm/__init__.py +13 -0
  173. tpu_inference/layers/vllm/attention.py +4 -4
  174. tpu_inference/layers/vllm/fused_moe.py +100 -455
  175. tpu_inference/layers/vllm/linear.py +64 -0
  176. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  177. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  178. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  179. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  180. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  181. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  182. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  183. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
  184. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  188. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
  189. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  190. tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
  191. tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
  192. tpu_inference/lora/__init__.py +13 -0
  193. tpu_inference/lora/torch_lora_ops.py +8 -13
  194. tpu_inference/models/__init__.py +13 -0
  195. tpu_inference/models/common/__init__.py +13 -0
  196. tpu_inference/models/common/model_loader.py +37 -16
  197. tpu_inference/models/jax/__init__.py +13 -0
  198. tpu_inference/models/jax/deepseek_v3.py +113 -124
  199. tpu_inference/models/jax/gpt_oss.py +23 -7
  200. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  201. tpu_inference/models/jax/llama3.py +99 -36
  202. tpu_inference/models/jax/llama4.py +14 -0
  203. tpu_inference/models/jax/llama_eagle3.py +14 -0
  204. tpu_inference/models/jax/llama_guard_4.py +15 -1
  205. tpu_inference/models/jax/qwen2.py +17 -2
  206. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  207. tpu_inference/models/jax/qwen3.py +17 -2
  208. tpu_inference/models/jax/utils/__init__.py +13 -0
  209. tpu_inference/models/jax/utils/file_utils.py +14 -0
  210. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  211. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
  213. tpu_inference/models/jax/utils/weight_utils.py +32 -1
  214. tpu_inference/models/vllm/__init__.py +13 -0
  215. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
  216. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  217. tpu_inference/platforms/__init__.py +14 -0
  218. tpu_inference/platforms/tpu_platform.py +27 -29
  219. tpu_inference/runner/__init__.py +13 -0
  220. tpu_inference/runner/compilation_manager.py +69 -35
  221. tpu_inference/runner/kv_cache.py +14 -0
  222. tpu_inference/runner/kv_cache_manager.py +15 -2
  223. tpu_inference/runner/lora_utils.py +16 -1
  224. tpu_inference/runner/multimodal_manager.py +16 -2
  225. tpu_inference/runner/persistent_batch_manager.py +14 -0
  226. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  227. tpu_inference/runner/structured_decoding_manager.py +14 -0
  228. tpu_inference/runner/tpu_runner.py +30 -10
  229. tpu_inference/spec_decode/__init__.py +13 -0
  230. tpu_inference/spec_decode/jax/__init__.py +13 -0
  231. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  232. tpu_inference/tpu_info.py +14 -0
  233. tpu_inference/utils.py +31 -30
  234. tpu_inference/worker/__init__.py +13 -0
  235. tpu_inference/worker/tpu_worker.py +23 -7
  236. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
  237. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  238. tpu_inference/layers/vllm/linear_common.py +0 -208
  239. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  240. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  241. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  242. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  245. tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
  246. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  247. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  248. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,174 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from dataclasses import dataclass, fields
+
+ import jax
+ import torch
+ from jax.sharding import Mesh, NamedSharding, PartitionSpec
+ from torch.nn import ParameterList
+ from torch.nn.parameter import Parameter
+ from torchax.tensor import Tensor
+
+ from tpu_inference.layers.common.utils import \
+     reorder_concatenated_tensor_for_sharding
+ from tpu_inference.logger import init_logger
+
+ P = PartitionSpec
+
+ logger = init_logger(__name__)
+
+
+ @jax.tree_util.register_dataclass
+ @dataclass
+ class LinearWeights:
+     weight: jax.Array | Tensor | list[jax.Array | Tensor]
+     weight_scale: jax.Array | Tensor | list[jax.Array | Tensor] | None
+     zero_point: jax.Array | Tensor | list[jax.Array | Tensor] | None
+     bias: jax.Array | Tensor | list[jax.Array | Tensor] | None
+
+
+ MODEL_MATMUL_FUSION_TRUTH_TABLE = {
+     ("Qwen/Qwen2.5-7B-Instruct", 1024, 1, "QKVParallelLinear"):
+         True,
+     ("Qwen/Qwen2.5-7B-Instruct", 1024, 1, "MergedColumnParallelLinear"):
+         False,
+     ("Qwen/Qwen2.5-7B-Instruct", 2048, 1, "QKVParallelLinear"):
+         False,
+     ("Qwen/Qwen2.5-7B-Instruct", 2048, 1, "MergedColumnParallelLinear"):
+         False,
+     ("meta-llama/Llama-3.1-8B-Instruct", 1024, 1, "QKVParallelLinear"):
+         False,
+     ("meta-llama/Llama-3.1-8B-Instruct", 1024, 1, "MergedColumnParallelLinear"):
+         False,
+     ("meta-llama/Llama-3.1-8B-Instruct", 2048, 1, "QKVParallelLinear"):
+         False,
+     ("meta-llama/Llama-3.1-8B-Instruct", 2048, 1, "MergedColumnParallelLinear"):
+         False,
+     ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 1024, 1, "QKVParallelLinear"):
+         False,
+     ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 1024, 1, "MergedColumnParallelLinear"):
+         False,
+     ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 2048, 1, "QKVParallelLinear"):
+         False,
+     ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 2048, 1, "MergedColumnParallelLinear"):
+         False,
+ }
+
+
+ def to_parameter_list(tensor: list[torch.Tensor]):
+     tensor = [Parameter(t, requires_grad=False) for t in tensor]
+     return ParameterList(tensor)
+
+
+ def get_model_matmul_fusion_assignment(model_name: str, batch_size: int,
+                                        tp_size: int, layer_name: str):
+     key = (model_name, batch_size, tp_size, layer_name)
+     return MODEL_MATMUL_FUSION_TRUTH_TABLE.get(key, True)
+
+
+ def process_lienar_weights(
+     weights: LinearWeights,
+     fused: bool = False,
+     output_sizes: list[int] | None = None,
+     reorder_size: int | None = None,
+     transposed: bool = True,
+     per_tensor: bool = False,
+ ) -> LinearWeights:
+     weight = weights.weight
+     weight_scale = weights.weight_scale
+     zero_point = weights.zero_point
+     bias = weights.bias
+
+     dim = 0 if transposed else -1
+     if output_sizes is None:
+         output_sizes = [weight.shape[dim]]
+
+     if fused:
+         assert reorder_size is not None
+         weight = reorder_concatenated_tensor_for_sharding(
+             weight, output_sizes, reorder_size, dim)
+
+         if weight_scale is not None and not per_tensor:
+             weight_scale = reorder_concatenated_tensor_for_sharding(
+                 weight_scale, output_sizes, reorder_size, dim)
+         if zero_point is not None:
+             zero_point = reorder_concatenated_tensor_for_sharding(
+                 zero_point, output_sizes, reorder_size, dim)
+         if bias is not None:
+             bias = reorder_concatenated_tensor_for_sharding(
+                 bias, output_sizes, reorder_size, dim)
+     else:
+
+         def slice_tensor(tensor):
+             tensors = []
+             start = 0
+             for size in output_sizes:
+                 end = start + size
+                 tensor_split = jax.lax.slice_in_dim(tensor,
+                                                     start,
+                                                     end,
+                                                     axis=dim)
+                 tensors.append(tensor_split)
+                 start = end
+             return tensors
+
+         weight = slice_tensor(weight)
+         if weight_scale is not None and not per_tensor:
+             weight_scale = slice_tensor(weight_scale)
+         if zero_point is not None:
+             zero_point = slice_tensor(zero_point)
+         if bias is not None:
+             bias = slice_tensor(bias)
+
+     return LinearWeights(
+         weight=weight,
+         weight_scale=weight_scale,
+         zero_point=zero_point,
+         bias=bias,
+     )
+
+
+ def shard_linear_weights(
+     weights: LinearWeights,
+     mesh: Mesh,
+     weight_p_spec: PartitionSpec,
+     bias_p_spec: PartitionSpec,
+     transposed: bool = True,
+     per_tensor: bool = False,
+ ) -> LinearWeights:
+
+     if not transposed:
+         # By defualt, we use transposed weights. If it is not transposed,
+         # we need to transpose the sharding as well.
+         weight_p_spec = PartitionSpec(*weight_p_spec[::-1])
+         bias_p_spec = PartitionSpec(weight_p_spec[0])
+
+     weight_sharding = NamedSharding(mesh, weight_p_spec)
+     bias_sharding = NamedSharding(mesh, bias_p_spec)
+
+     weight_shardings = LinearWeights(
+         weight=weight_sharding,
+         weight_scale=NamedSharding(mesh, P()) if per_tensor else bias_sharding,
+         zero_point=bias_sharding,
+         bias=bias_sharding,
+     )
+
+     for field in fields(LinearWeights):
+         key = field.name
+         if (weight := getattr(weights, key, None)) is not None:
+             sharding = getattr(weight_shardings, key)
+             weight = jax.device_put(weight, sharding)
+             setattr(weights, key, weight)
+     return weights
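The 174-line file added above appears to be the new tpu_inference/layers/vllm/process_weights/linear_weights.py helper module (the awq.py hunks further down import LinearWeights, process_lienar_weights, and shard_linear_weights from that path). Below is a minimal usage sketch, not code from the package: the shapes, mesh axis name, and output sizes are assumptions chosen for illustration. process_lienar_weights either slices a concatenated [out, in] weight back into per-projection tensors (fused=False) or reorders it for sharding (fused=True), and shard_linear_weights then places every non-None field on the mesh.

# Illustrative sketch only -- shapes, mesh axes, and output sizes are assumed.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec

from tpu_inference.layers.vllm.process_weights.linear_weights import (
    LinearWeights, process_lienar_weights, shard_linear_weights)

# Single-device mesh with one named axis, just for demonstration.
mesh = Mesh(jax.devices()[:1], axis_names=("model",))

# A concatenated [out, in] weight holding two merged 128-row projections.
weights = LinearWeights(
    weight=jnp.zeros((256, 512), dtype=jnp.bfloat16),
    weight_scale=None,
    zero_point=None,
    bias=None,
)

# fused=False slices the output dim back into per-projection tensors;
# fused=True would instead reorder it so every shard sees each projection.
weights = process_lienar_weights(weights, fused=False, output_sizes=[128, 128])

# Place each non-None field (arrays or lists of arrays) on the mesh.
weights = shard_linear_weights(
    weights,
    mesh=mesh,
    weight_p_spec=PartitionSpec("model", None),
    bias_p_spec=PartitionSpec("model"),
)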
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import copy

  from jax.sharding import Mesh
@@ -7,9 +21,10 @@ from vllm.model_executor.layers.quantization.base_config import \

  from tpu_inference.layers.common import quant_methods
  from tpu_inference.layers.vllm.quantization.awq import VllmAWQConfig
- from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
  from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
-     VllmCompressedTensorsConfig  # noqa: E501
+     VllmCompressedTensorsConfig
+ from tpu_inference.layers.vllm.quantization.configs import VllmQuantConfig
+ from tpu_inference.layers.vllm.quantization.fp8 import VllmFp8Config
  from tpu_inference.layers.vllm.quantization.mxfp4 import VllmMxfp4Config
  from tpu_inference.layers.vllm.quantization.unquantized import \
      VllmUnquantizedConfig
@@ -23,6 +38,7 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
          None: VllmUnquantizedConfig,
          quant_methods.COMPRESSED_TENSORS: VllmCompressedTensorsConfig,
          quant_methods.AWQ: VllmAWQConfig,
+         quant_methods.FP8: VllmFp8Config,
          quant_methods.MXFP4: VllmMxfp4Config,
      }
      if model_config.quantization not in method_to_config:
@@ -30,7 +46,7 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
              f"{model_config.quantization} quantization method not supported."
              f" Supported methods are {method_to_config.keys()}")
      quant_config = method_to_config[model_config.quantization]
-     assert issubclass(quant_config, JaxCommonConfig)
+     assert issubclass(quant_config, VllmQuantConfig)
      quant_config.set_configs(vllm_config, mesh)

      model_config.quantization = quant_methods.get_tpu_quant_method(
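The hunks above are from tpu_inference/layers/vllm/quantization/__init__.py: the TPU quantization entry point gains an FP8 config class and now validates the chosen class against the renamed VllmQuantConfig base (formerly JaxCommonConfig). A stripped-down sketch of that dispatch pattern is shown below; the classes here are stand-ins for illustration, not the package's real implementations.

# Illustrative sketch of the method-to-config dispatch, with stand-in classes.
class VllmQuantConfig:
    """Shared base that records the vLLM config and the TPU mesh."""

    @classmethod
    def set_configs(cls, vllm_config, mesh):
        cls.vllm_config = vllm_config
        cls.mesh = mesh


class VllmUnquantizedConfig(VllmQuantConfig):
    pass


class VllmFp8Config(VllmQuantConfig):
    pass


METHOD_TO_CONFIG = {None: VllmUnquantizedConfig, "fp8": VllmFp8Config}


def get_tpu_quantization_config(quantization, vllm_config, mesh):
    if quantization not in METHOD_TO_CONFIG:
        raise ValueError(f"{quantization} quantization method not supported. "
                         f"Supported methods are {METHOD_TO_CONFIG.keys()}")
    config_cls = METHOD_TO_CONFIG[quantization]
    assert issubclass(config_cls, VllmQuantConfig)
    config_cls.set_configs(vllm_config, mesh)
    return config_cls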
@@ -1,11 +1,26 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from typing import Optional, Union

  import jax
  import jax.numpy as jnp
  import torch
- from jax.sharding import NamedSharding, PartitionSpec
+ from jax.sharding import PartitionSpec
+ from torch.nn.parameter import Parameter
  from torchax.interop import jax_view, torch_view
- from vllm.logger import init_logger
+ from torchax.ops.mappings import t2j
  from vllm.model_executor.layers.fused_moe.layer import FusedMoE
  from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
  from vllm.model_executor.layers.quantization import \
@@ -14,24 +29,29 @@ from vllm.model_executor.layers.quantization.awq import (AWQConfig,
                                                           AWQLinearMethod)
  from vllm.model_executor.layers.quantization.base_config import \
      QuantizeMethodBase
- from vllm.model_executor.layers.quantization.utils.quant_utils import (
-     is_layer_skipped, unpack_quantized_values_into_int32)
- from vllm.scalar_type import scalar_types
+ from vllm.model_executor.layers.quantization.utils.quant_utils import \
+     is_layer_skipped

  from tpu_inference.layers.common.quant_methods import AWQ, get_tpu_quant_method
- from tpu_inference.layers.vllm.linear_common import (
-     slice_sharded_tensor_for_concatenation, torch_to_jax_param)
- from tpu_inference.layers.vllm.quantization.common import (
-     JaxCommonConfig, JaxCommonLinearConfig)
+ from tpu_inference.layers.common.quantization import awq_u32_unpack_u4
+ from tpu_inference.layers.common.utils import \
+     slice_sharded_tensor_for_concatenation
+ from tpu_inference.layers.vllm.process_weights.linear_weights import (
+     LinearWeights, process_lienar_weights, shard_linear_weights,
+     to_parameter_list)
+ from tpu_inference.layers.vllm.quantization.configs import (
+     VllmQuantConfig, VllmQuantLinearConfig)
  from tpu_inference.layers.vllm.quantization.unquantized import \
      VllmUnquantizedLinearMethod
+ from tpu_inference.logger import init_logger

  P = PartitionSpec
+
  logger = init_logger(__name__)


  @register_quantization_config(get_tpu_quant_method(AWQ))
- class VllmAWQConfig(AWQConfig, JaxCommonConfig):
+ class VllmAWQConfig(AWQConfig, VllmQuantConfig):

      @classmethod
      def get_name(cls):
@@ -39,7 +59,7 @@ class VllmAWQConfig(AWQConfig, JaxCommonConfig):

      def get_supported_act_dtypes(self) -> list[torch.dtype]:
          # NOTE: AWQ checkpoint was quantized with float16. But on TPUs, using
-         # bfloat16 is signifcantly preferred over foat16. This might lead to
+         # bfloat16 is significantly preferred over float16. This might lead to
          # some numeric output change.
          return [torch.bfloat16]

@@ -60,72 +80,79 @@ class VllmAWQConfig(AWQConfig, JaxCommonConfig):
  class VllmAWQLinearMethod(AWQLinearMethod):

      def __init__(self, quant_config: VllmAWQConfig,
-                  jax_config: JaxCommonLinearConfig):
+                  linear_config: VllmQuantLinearConfig):
          super().__init__(quant_config)
-         self.jax_config = jax_config
-
-         out_sharding, in_sharding = self.jax_config.weight_sharding[:]
-         self.jax_config.weight_sharding = P(in_sharding, None, out_sharding)
-         self.jax_config.scale_sharding = P(in_sharding, out_sharding)
+         self.linear_config = linear_config

      def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-         qweight = layer.qweight
-         qweight = unpack_awq_weight(qweight, qweight.packed_dim)
-
-         group_size = self.quant_config.group_size
-         # Reshape so that each qweight[i] were quantized with same scales[i].
-         qweight = qweight.reshape((-1, group_size, layer.output_size))
-         qweight = torch_to_jax_param(qweight,
-                                      NamedSharding(
-                                          self.jax_config.mesh,
-                                          self.jax_config.weight_sharding),
-                                      self.jax_config.output_sizes,
-                                      self.jax_config.n_shards,
-                                      self.jax_config.fuse_matmuls,
-                                      dim=2,
-                                      jax_dtype=jnp.uint4)
+         assert layer.qweight.packed_dim == layer.qweight.ndim - 1
+         weight = t2j(layer.qweight, use_dlpack=False)
          delattr(layer, "qweight")
-         layer.qweight = qweight
-
-         qzeros = layer.qzeros
-         qzeros = unpack_awq_weight(qzeros, qzeros.packed_dim)
-         qzeros = torch_to_jax_param(qzeros,
-                                     NamedSharding(
-                                         self.jax_config.mesh,
-                                         self.jax_config.scale_sharding),
-                                     self.jax_config.output_sizes,
-                                     self.jax_config.n_shards,
-                                     self.jax_config.fuse_matmuls,
-                                     dim=1,
-                                     jax_dtype=jnp.uint4)
-         delattr(layer, "qzeros")
-         layer.qzeros = qzeros
-
-         scales = torch_to_jax_param(layer.scales,
-                                     NamedSharding(
-                                         self.jax_config.mesh,
-                                         self.jax_config.scale_sharding),
-                                     self.jax_config.output_sizes,
-                                     self.jax_config.n_shards,
-                                     self.jax_config.fuse_matmuls,
-                                     dim=1)
+
+         weight_scale = t2j(layer.scales, use_dlpack=False)
          delattr(layer, "scales")
-         layer.scales = scales
+
+         assert layer.qzeros.packed_dim == layer.qzeros.ndim - 1
+         zero_point = t2j(layer.qzeros, use_dlpack=False)
+         delattr(layer, "qzeros")

          if layer.bias is not None and not layer.skip_bias_add:
              if layer.return_bias:
                  logger.warning_once("Bias might return incorrect value.")
-
-             bias = torch_to_jax_param(
-                 layer.bias,
-                 NamedSharding(self.jax_config.mesh,
-                               self.jax_config.bias_sharding),
-                 self.jax_config.output_sizes,
-                 self.jax_config.n_shards,
-                 self.jax_config.fuse_matmuls,
-             )
+             bias = t2j(layer.bias, use_dlpack=False)
              delattr(layer, "bias")
-             layer.bias = bias
+         else:
+             bias = None
+
+         @jax.jit
+         def process_awq_linear_weights(
+             weight: jax.Array,
+             weight_scale: jax.Array,
+             zero_point: jax.Array,
+             bias: jax.Array | None,
+         ) -> LinearWeights:
+             weight = awq_u32_unpack_u4(weight)
+             group_size = self.quant_config.group_size
+             weight = weight.reshape((-1, group_size, weight.shape[-1]))
+
+             zero_point = awq_u32_unpack_u4(zero_point)
+
+             return process_lienar_weights(
+                 LinearWeights(
+                     weight=weight,
+                     weight_scale=weight_scale,
+                     zero_point=zero_point,
+                     bias=bias,
+                 ),
+                 fused=self.linear_config.fuse_matmuls,
+                 output_sizes=self.linear_config.output_sizes,
+                 reorder_size=self.linear_config.n_shards,
+                 transposed=False,
+             )
+
+         weights = process_awq_linear_weights(weight, weight_scale, zero_point,
+                                              bias)
+         weights = torch_view(
+             shard_linear_weights(
+                 weights,
+                 mesh=self.linear_config.mesh,
+                 weight_p_spec=self.linear_config.weight_sharding,
+                 bias_p_spec=self.linear_config.bias_sharding,
+                 transposed=False,
+             ))
+
+         if self.linear_config.fuse_matmuls:
+             layer.qweight = Parameter(weights.weight, requires_grad=False)
+             layer.scales = Parameter(weights.weight_scale, requires_grad=False)
+             layer.qzeros = Parameter(weights.zero_point, requires_grad=False)
+             if bias is not None:
+                 layer.bias = Parameter(weights.bias, requires_grad=False)
+         else:
+             layer.qweight = to_parameter_list(weights.weight)
+             layer.scales = to_parameter_list(weights.weight_scale)
+             layer.qzeros = to_parameter_list(weights.zero_point)
+             if bias is not None:
+                 layer.bias = to_parameter_list(weights.bias)

      def apply(self,
                layer: torch.nn.Module,
@@ -133,7 +160,7 @@ class VllmAWQLinearMethod(AWQLinearMethod):
                bias: Optional[torch.Tensor] = None) -> torch.Tensor:

          with jax.named_scope(layer._get_name()):
-             if self.jax_config.fuse_matmuls:
+             if self.linear_config.fuse_matmuls:
                  out = self._apply_fused(layer, x, bias)
              else:
                  out = self._apply_split(layer, x, bias)
@@ -161,7 +188,7 @@ class VllmAWQLinearMethod(AWQLinearMethod):
              outs += bias.jax()

          outs = slice_sharded_tensor_for_concatenation(
-             outs, self.jax_config.output_sizes, self.jax_config.n_shards)
+             outs, self.linear_config.output_sizes, self.linear_config.n_shards)
          out = jnp.concatenate(outs, axis=-1)
          return torch_view(out)

@@ -192,16 +219,3 @@ class VllmAWQLinearMethod(AWQLinearMethod):
              outs.append(out)
          out = jnp.concatenate(outs, axis=-1)
          return torch_view(out)
-
-
- def unpack_awq_weight(weight: torch.Tensor, packed_dim: int):
-     weight = unpack_quantized_values_into_int32(weight, scalar_types.uint4,
-                                                  packed_dim)
-
-     # AWQ packs 8 uint4 into 32-bits in this order: (0, 2, 4, 6, 1, 3, 5, 7).
-     # Following list maps the order used by AWQ into an ascending order.
-     reverse_awq_order = (0, 4, 1, 5, 2, 6, 3, 7)
-
-     orig_shape = weight.shape
-     weight = weight.reshape(orig_shape[:-1] + (-1, 8))
-     return weight[..., reverse_awq_order].reshape(orig_shape)
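The awq.py rewrite above drops the Torch-side unpack_awq_weight helper; unpacking now happens inside a jitted function via awq_u32_unpack_u4 from tpu_inference.layers.common.quantization. The removed comment documents AWQ's layout: eight uint4 values per uint32, stored in the order (0, 2, 4, 6, 1, 3, 5, 7). Below is a minimal JAX sketch of that unpack-and-reorder step, written under the assumption that the new helper behaves equivalently; its actual implementation is not shown in this diff.

# Illustrative sketch of AWQ uint32 -> uint4 unpacking, not the package's code.
import jax.numpy as jnp

# Maps AWQ's interleaved nibble order (0, 2, 4, 6, 1, 3, 5, 7) back to ascending order.
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]


def awq_unpack_u32_to_u4(packed):
    """Unpack each uint32 into its 8 uint4 values along the last dim."""
    shifts = jnp.arange(8, dtype=jnp.uint32) * 4
    nibbles = (packed[..., None] >> shifts) & 0xF   # (..., n, 8), lowest nibble first
    nibbles = nibbles[..., AWQ_REVERSE_ORDER]       # undo AWQ's packing order
    return nibbles.reshape(packed.shape[:-1] + (-1,))   # values in [0, 15]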
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -1,9 +1,22 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from typing import Optional

  import torch
  from jax.sharding import PartitionSpec
  from vllm.attention.layer import Attention
- from vllm.logger import init_logger
  from vllm.model_executor.layers.fused_moe.layer import FusedMoE
  from vllm.model_executor.layers.linear import LinearBase
  from vllm.model_executor.layers.quantization import \
@@ -18,22 +31,23 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (

  from tpu_inference.layers.common.quant_methods import (COMPRESSED_TENSORS,
                                                          get_tpu_quant_method)
- from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
  from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
      VllmCompressedTensorsMoEMethod
  from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
      VllmCompressedTensorsW8A8Fp8
  from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
      VllmCompressedTensorsW8A8Int8
+ from tpu_inference.layers.vllm.quantization.configs import VllmQuantConfig
  from tpu_inference.layers.vllm.quantization.unquantized import \
      VllmUnquantizedConfig
+ from tpu_inference.logger import init_logger

  P = PartitionSpec
  logger = init_logger(__name__)


  @register_quantization_config(get_tpu_quant_method(COMPRESSED_TENSORS))
- class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
+ class VllmCompressedTensorsConfig(CompressedTensorsConfig, VllmQuantConfig):

      @classmethod
      def get_name(cls) -> str:
@@ -84,14 +98,14 @@ class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
              return VllmCompressedTensorsW8A8Fp8(
                  weight_quant=weight_quant,
                  is_static_input_scheme=is_static_input_scheme,
-                 jax_config=linear_config,
+                 linear_config=linear_config,
              )
          if self._is_dynamic_token_w8a8(weight_quant, input_quant):
              return VllmCompressedTensorsW8A8Int8(
                  strategy=weight_quant.strategy,
                  is_static_input_scheme=False,
                  input_symmetric=input_quant.symmetric,
-                 jax_config=linear_config,
+                 linear_config=linear_config,
              )
          raise NotImplementedError(
              "No compressed-tensors compatible scheme was found.")