tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/awq.py

@@ -1,11 +1,26 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Optional, Union
 
 import jax
 import jax.numpy as jnp
 import torch
-from jax.sharding import NamedSharding, PartitionSpec
+from jax.sharding import PartitionSpec
+from torch.nn.parameter import Parameter
 from torchax.interop import jax_view, torch_view
-from vllm.logger import init_logger
+from torchax.ops.mappings import t2j
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization import \
@@ -14,24 +29,29 @@ from vllm.model_executor.layers.quantization.awq import (AWQConfig,
                                                          AWQLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import \
     QuantizeMethodBase
-from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    is_layer_skipped, unpack_quantized_values_into_int32)
-from vllm.scalar_type import scalar_types
+from vllm.model_executor.layers.quantization.utils.quant_utils import \
+    is_layer_skipped
 
 from tpu_inference.layers.common.quant_methods import AWQ, get_tpu_quant_method
-from tpu_inference.layers.vllm.linear_common import (
-    slice_sharded_tensor_for_concatenation, torch_to_jax_param)
-from tpu_inference.layers.vllm.quantization.common import (
-    JaxCommonConfig, JaxCommonLinearConfig)
+from tpu_inference.layers.common.quantization import awq_u32_unpack_u4
+from tpu_inference.layers.common.utils import \
+    slice_sharded_tensor_for_concatenation
+from tpu_inference.layers.vllm.process_weights.linear_weights import (
+    LinearWeights, process_lienar_weights, shard_linear_weights,
+    to_parameter_list)
+from tpu_inference.layers.vllm.quantization.configs import (
+    VllmQuantConfig, VllmQuantLinearConfig)
 from tpu_inference.layers.vllm.quantization.unquantized import \
     VllmUnquantizedLinearMethod
+from tpu_inference.logger import init_logger
 
 P = PartitionSpec
+
 logger = init_logger(__name__)
 
 
 @register_quantization_config(get_tpu_quant_method(AWQ))
-class VllmAWQConfig(AWQConfig, JaxCommonConfig):
+class VllmAWQConfig(AWQConfig, VllmQuantConfig):
 
     @classmethod
     def get_name(cls):
@@ -39,7 +59,7 @@ class VllmAWQConfig(AWQConfig, JaxCommonConfig):
 
     def get_supported_act_dtypes(self) -> list[torch.dtype]:
         # NOTE: AWQ checkpoint was quantized with float16. But on TPUs, using
-        # bfloat16 is signifcantly preferred over foat16. This might lead to
+        # bfloat16 is significantly preferred over float16. This might lead to
         # some numeric output change.
         return [torch.bfloat16]
 
@@ -60,72 +80,79 @@ class VllmAWQConfig(AWQConfig, JaxCommonConfig):
 class VllmAWQLinearMethod(AWQLinearMethod):
 
     def __init__(self, quant_config: VllmAWQConfig,
-                 jax_config: JaxCommonLinearConfig):
+                 linear_config: VllmQuantLinearConfig):
         super().__init__(quant_config)
-        self.jax_config = jax_config
-
-        out_sharding, in_sharding = self.jax_config.weight_sharding[:]
-        self.jax_config.weight_sharding = P(in_sharding, None, out_sharding)
-        self.jax_config.scale_sharding = P(in_sharding, out_sharding)
+        self.linear_config = linear_config
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        qweight = layer.qweight
-        qweight = unpack_awq_weight(qweight, qweight.packed_dim)
-
-        group_size = self.quant_config.group_size
-        # Reshape so that each qweight[i] were quantized with same scales[i].
-        qweight = qweight.reshape((-1, group_size, layer.output_size))
-        qweight = torch_to_jax_param(qweight,
-                                     NamedSharding(
-                                         self.jax_config.mesh,
-                                         self.jax_config.weight_sharding),
-                                     self.jax_config.output_sizes,
-                                     self.jax_config.n_shards,
-                                     self.jax_config.fuse_matmuls,
-                                     dim=2,
-                                     jax_dtype=jnp.uint4)
+        assert layer.qweight.packed_dim == layer.qweight.ndim - 1
+        weight = t2j(layer.qweight, use_dlpack=False)
         delattr(layer, "qweight")
-        layer.qweight = qweight
-
-        qzeros = layer.qzeros
-        qzeros = unpack_awq_weight(qzeros, qzeros.packed_dim)
-        qzeros = torch_to_jax_param(qzeros,
-                                    NamedSharding(
-                                        self.jax_config.mesh,
-                                        self.jax_config.scale_sharding),
-                                    self.jax_config.output_sizes,
-                                    self.jax_config.n_shards,
-                                    self.jax_config.fuse_matmuls,
-                                    dim=1,
-                                    jax_dtype=jnp.uint4)
-        delattr(layer, "qzeros")
-        layer.qzeros = qzeros
-
-        scales = torch_to_jax_param(layer.scales,
-                                    NamedSharding(
-                                        self.jax_config.mesh,
-                                        self.jax_config.scale_sharding),
-                                    self.jax_config.output_sizes,
-                                    self.jax_config.n_shards,
-                                    self.jax_config.fuse_matmuls,
-                                    dim=1)
+
+        weight_scale = t2j(layer.scales, use_dlpack=False)
         delattr(layer, "scales")
-        layer.scales = scales
+
+        assert layer.qzeros.packed_dim == layer.qzeros.ndim - 1
+        zero_point = t2j(layer.qzeros, use_dlpack=False)
+        delattr(layer, "qzeros")
 
         if layer.bias is not None and not layer.skip_bias_add:
             if layer.return_bias:
                 logger.warning_once("Bias might return incorrect value.")
-
-            bias = torch_to_jax_param(
-                layer.bias,
-                NamedSharding(self.jax_config.mesh,
-                              self.jax_config.bias_sharding),
-                self.jax_config.output_sizes,
-                self.jax_config.n_shards,
-                self.jax_config.fuse_matmuls,
-            )
+            bias = t2j(layer.bias, use_dlpack=False)
             delattr(layer, "bias")
-            layer.bias = bias
+        else:
+            bias = None
+
+        @jax.jit
+        def process_awq_linear_weights(
+            weight: jax.Array,
+            weight_scale: jax.Array,
+            zero_point: jax.Array,
+            bias: jax.Array | None,
+        ) -> LinearWeights:
+            weight = awq_u32_unpack_u4(weight)
+            group_size = self.quant_config.group_size
+            weight = weight.reshape((-1, group_size, weight.shape[-1]))
+
+            zero_point = awq_u32_unpack_u4(zero_point)
+
+            return process_lienar_weights(
+                LinearWeights(
+                    weight=weight,
+                    weight_scale=weight_scale,
+                    zero_point=zero_point,
+                    bias=bias,
+                ),
+                fused=self.linear_config.fuse_matmuls,
+                output_sizes=self.linear_config.output_sizes,
+                reorder_size=self.linear_config.n_shards,
+                transposed=False,
+            )
+
+        weights = process_awq_linear_weights(weight, weight_scale, zero_point,
+                                             bias)
+        weights = torch_view(
+            shard_linear_weights(
+                weights,
+                mesh=self.linear_config.mesh,
+                weight_p_spec=self.linear_config.weight_sharding,
+                bias_p_spec=self.linear_config.bias_sharding,
+                transposed=False,
+            ))
+
+        if self.linear_config.fuse_matmuls:
+            layer.qweight = Parameter(weights.weight, requires_grad=False)
+            layer.scales = Parameter(weights.weight_scale, requires_grad=False)
+            layer.qzeros = Parameter(weights.zero_point, requires_grad=False)
+            if bias is not None:
+                layer.bias = Parameter(weights.bias, requires_grad=False)
+        else:
+            layer.qweight = to_parameter_list(weights.weight)
+            layer.scales = to_parameter_list(weights.weight_scale)
+            layer.qzeros = to_parameter_list(weights.zero_point)
+            if bias is not None:
+                layer.bias = to_parameter_list(weights.bias)
 
     def apply(self,
               layer: torch.nn.Module,
@@ -133,7 +160,7 @@ class VllmAWQLinearMethod(AWQLinearMethod):
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
         with jax.named_scope(layer._get_name()):
-            if self.jax_config.fuse_matmuls:
+            if self.linear_config.fuse_matmuls:
                 out = self._apply_fused(layer, x, bias)
             else:
                 out = self._apply_split(layer, x, bias)
@@ -161,7 +188,7 @@ class VllmAWQLinearMethod(AWQLinearMethod):
             outs += bias.jax()
 
         outs = slice_sharded_tensor_for_concatenation(
-            outs, self.jax_config.output_sizes, self.jax_config.n_shards)
+            outs, self.linear_config.output_sizes, self.linear_config.n_shards)
         out = jnp.concatenate(outs, axis=-1)
         return torch_view(out)
 
@@ -192,16 +219,3 @@ class VllmAWQLinearMethod(AWQLinearMethod):
             outs.append(out)
         out = jnp.concatenate(outs, axis=-1)
         return torch_view(out)
-
-
-def unpack_awq_weight(weight: torch.Tensor, packed_dim: int):
-    weight = unpack_quantized_values_into_int32(weight, scalar_types.uint4,
-                                                packed_dim)
-
-    # AWQ packs 8 uint4 into 32-bits in this order: (0, 2, 4, 6, 1, 3, 5, 7).
-    # Following list maps the order used by AWQ into an ascending order.
-    reverse_awq_order = (0, 4, 1, 5, 2, 6, 3, 7)
-
-    orig_shape = weight.shape
-    weight = weight.reshape(orig_shape[:-1] + (-1, 8))
-    return weight[..., reverse_awq_order].reshape(orig_shape)
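
For readers skimming the AWQ changes: both the deleted `unpack_awq_weight` helper and the new jitted `process_awq_linear_weights` hinge on AWQ's nibble layout and group-wise scales. Below is a minimal, self-contained sketch of that unpacking and the reference dequantization it enables; the function names, the uint8 output dtype, and `awq_dequantize` itself are illustrative assumptions, not the package's API (its real helper is `awq_u32_unpack_u4`).

```python
import jax
import jax.numpy as jnp

# AWQ packs 8 uint4 values into each uint32 in the order
# (0, 2, 4, 6, 1, 3, 5, 7); gathering with this permutation restores
# ascending order once the nibbles have been split out.
REVERSE_AWQ_ORDER = jnp.array([0, 4, 1, 5, 2, 6, 3, 7])


def unpack_u32_to_u4(packed: jax.Array) -> jax.Array:
    """Split a (..., n) uint32 array into (..., n * 8) 4-bit values."""
    shifts = jnp.arange(8, dtype=jnp.uint32) * 4   # bit offset of each nibble
    nibbles = (packed[..., None] >> shifts) & 0xF  # (..., n, 8)
    nibbles = nibbles[..., REVERSE_AWQ_ORDER]      # undo AWQ's interleaving
    # uint8 here for portability; the package's helper produces uint4.
    return nibbles.reshape(*packed.shape[:-1], -1).astype(jnp.uint8)


def awq_dequantize(qweight: jax.Array, scales: jax.Array,
                   qzeros: jax.Array, group_size: int) -> jax.Array:
    """Reference AWQ dequantization: (q - zero) * scale, per group."""
    w = unpack_u32_to_u4(qweight).astype(jnp.float32)  # (in, out)
    z = unpack_u32_to_u4(qzeros).astype(jnp.float32)   # (groups, out)
    # Group rows so that w[g] shares scales[g] and z[g], mirroring the
    # reshape((-1, group_size, weight.shape[-1])) in the diff above.
    w = w.reshape(-1, group_size, w.shape[-1])
    return ((w - z[:, None, :]) * scales[:, None, :]).reshape(-1, w.shape[-1])
```

The permutation (0, 4, 1, 5, 2, 6, 3, 7) is simply the inverse of AWQ's pack order (0, 2, 4, 6, 1, 3, 5, 7): slot 1 holds logical element 2, so logical element 1 must be fetched from slot 4, and so on.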
tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py

@@ -1,9 +1,22 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Optional
 
 import torch
 from jax.sharding import PartitionSpec
 from vllm.attention.layer import Attention
-from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.linear import LinearBase
 from vllm.model_executor.layers.quantization import \
@@ -18,22 +31,23 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
 
 from tpu_inference.layers.common.quant_methods import (COMPRESSED_TENSORS,
                                                        get_tpu_quant_method)
-from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
-    VllmCompressedTensorsW8A8Fp8MoEMethod
+    VllmCompressedTensorsMoEMethod
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
     VllmCompressedTensorsW8A8Fp8
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
     VllmCompressedTensorsW8A8Int8
+from tpu_inference.layers.vllm.quantization.configs import VllmQuantConfig
 from tpu_inference.layers.vllm.quantization.unquantized import \
     VllmUnquantizedConfig
+from tpu_inference.logger import init_logger
 
 P = PartitionSpec
 logger = init_logger(__name__)
 
 
 @register_quantization_config(get_tpu_quant_method(COMPRESSED_TENSORS))
-class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
+class VllmCompressedTensorsConfig(CompressedTensorsConfig, VllmQuantConfig):
 
     @classmethod
     def get_name(cls) -> str:
@@ -84,14 +98,14 @@ class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
             return VllmCompressedTensorsW8A8Fp8(
                 weight_quant=weight_quant,
                 is_static_input_scheme=is_static_input_scheme,
-                jax_config=linear_config,
+                linear_config=linear_config,
             )
         if self._is_dynamic_token_w8a8(weight_quant, input_quant):
             return VllmCompressedTensorsW8A8Int8(
                 strategy=weight_quant.strategy,
                 is_static_input_scheme=False,
                 input_symmetric=input_quant.symmetric,
-                jax_config=linear_config,
+                linear_config=linear_config,
            )
         raise NotImplementedError(
             "No compressed-tensors compatible scheme was found.")
@@ -113,8 +127,9 @@ class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
             layer.scheme = scheme
             return CompressedTensorsLinearMethod(self)
         if isinstance(layer, FusedMoE):
-            return VllmCompressedTensorsW8A8Fp8MoEMethod(
-                self, layer.quant_config, self.mesh)
+            layer.moe_config = self.get_moe_config(layer)
+            return VllmCompressedTensorsMoEMethod.get_moe_method(
+                self, layer, layer_name=prefix)
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
         return None
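
One structural note on the FusedMoE path: where the old code always constructed `VllmCompressedTensorsW8A8Fp8MoEMethod`, the new code stores a `moe_config` on the layer and defers to `VllmCompressedTensorsMoEMethod.get_moe_method`, so additional quantization schemes can be routed behind a single entry point. A sketch of that factory-dispatch shape, with entirely hypothetical names and a placeholder scheme lookup, might look like:

```python
from typing import Dict, Type


class MoEMethodFactorySketch:
    """Hypothetical factory: routes a FusedMoE layer to a concrete method."""

    _registry: Dict[str, Type["MoEMethodFactorySketch"]] = {}

    def __init__(self, quant_config, layer):
        self.quant_config = quant_config
        self.layer = layer

    @classmethod
    def register(cls, scheme: str):
        # Subclasses register themselves under a scheme name.
        def wrap(subcls):
            cls._registry[scheme] = subcls
            return subcls
        return wrap

    @classmethod
    def get_moe_method(cls, quant_config, layer, layer_name: str):
        # Placeholder lookup: a real implementation would read the scheme
        # the checkpoint declares for `layer_name` from quant_config.
        scheme = getattr(layer, "moe_scheme", "w8a8_fp8")
        if scheme not in cls._registry:
            raise NotImplementedError(f"No MoE method for scheme {scheme!r}")
        return cls._registry[scheme](quant_config, layer)


@MoEMethodFactorySketch.register("w8a8_fp8")
class W8A8Fp8MoESketch(MoEMethodFactorySketch):
    pass
```

The call site then stays fixed while new scheme implementations are added by registration alone, which is the usual motivation for this pattern.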