tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference has been flagged as potentially problematic; more details are available on the release page.

Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,18 @@
- import os
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.

  import jax
- import jax.numpy as jnp
  import torch
  import torchax
  from jax.sharding import Mesh, NamedSharding, PartitionSpec
@@ -9,6 +20,7 @@ from torch.nn import Parameter
  from torch.utils import _pytree as pytree
  from torchax.interop import jax_view, torch_view
  from torchax.ops.mappings import t2j
+ from vllm import envs as vllm_envs
  from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                                MergedColumnParallelLinearWithLoRA,
                                MergedQKVParallelLinearWithLoRA,
@@ -20,18 +32,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
      ParallelLMHead, VocabParallelEmbedding)

  from tpu_inference import envs
+ from tpu_inference.layers.common.sharding import ShardingAxisName
  from tpu_inference.logger import init_logger
+ from tpu_inference.utils import to_jax_dtype

  P = PartitionSpec

  logger = init_logger(__name__)

- TORCH_TO_JAX_DTYPE_MAP = {
-     torch.float32: jnp.float32,
-     torch.float16: jnp.float16,
-     torch.bfloat16: jnp.bfloat16,
- }
-

  def shard_model_to_tpu(model: torch.nn.Module,
                         mesh: Mesh) -> dict[str, torchax.torch.Tensor]:
@@ -88,10 +96,9 @@ def _tensor_is_in_cpu(tensor: torch.tensor) -> bool:

  def _convert_to_torchax_and_shard(tensor: torch.Tensor,
                                    sharding: NamedSharding) -> torch.Tensor:
-     if os.getenv("VLLM_TPU_USING_PATHWAYS", False) and isinstance(
-             tensor, torch.Tensor):
+     if vllm_envs.VLLM_TPU_USING_PATHWAYS and isinstance(tensor, torch.Tensor):
          np_tensor = tensor.detach().cpu().to(torch.float32).numpy()
-         dtype = TORCH_TO_JAX_DTYPE_MAP.get(tensor.dtype, jnp.float32)
+         dtype = to_jax_dtype(tensor.dtype)
          return torch_view(jax.device_put(np_tensor, sharding).astype(dtype))
      else:
          if isinstance(tensor, torchax.tensor.Tensor):
@@ -109,7 +116,8 @@ def _shard_tensor_to_tpu_replicated(tensor: torch.Tensor,
  def _shard_vocab_parallel_embedding(layer: VocabParallelEmbedding,
                                      mesh: Mesh) -> None:
      weight = _convert_to_torchax_and_shard(
-         layer.weight, NamedSharding(mesh, P('model', None)))
+         layer.weight, NamedSharding(mesh, P(ShardingAxisName.MLP_TENSOR,
+                                             None)))
      layer.weight = Parameter(weight, requires_grad=False)


@@ -118,11 +126,12 @@ def _shard_lm_head(layer: ParallelLMHead, mesh: Mesh):
      # if that config is set, then we should not create new weights but reuse the
      # weight from VocabParallelEmbedding
      weight = _convert_to_torchax_and_shard(
-         layer.weight, NamedSharding(mesh, P('model', None)))
+         layer.weight, NamedSharding(mesh, P(ShardingAxisName.MLP_TENSOR,
+                                             None)))
      layer.weight = Parameter(weight, requires_grad=False)
      if layer.bias is not None:
-         bias = _convert_to_torchax_and_shard(layer.bias,
-                                              NamedSharding(mesh, P('model')))
+         bias = _convert_to_torchax_and_shard(
+             layer.bias, NamedSharding(mesh, P(ShardingAxisName.MLP_TENSOR)))
          layer.bias = Parameter(bias, requires_grad=False)

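Note (not part of the diff): the hunk above replaces the module-local TORCH_TO_JAX_DTYPE_MAP with a shared to_jax_dtype helper from tpu_inference.utils and swaps the hard-coded 'model' axis for ShardingAxisName.MLP_TENSOR. A minimal sketch of what such a dtype helper might look like, assuming it keeps the removed mapping's float32 fallback (the real implementation lives in tpu_inference/utils.py and may differ):

import jax.numpy as jnp
import torch

# Hypothetical stand-in mirroring the removed TORCH_TO_JAX_DTYPE_MAP,
# including its jnp.float32 fallback for unmapped torch dtypes.
def to_jax_dtype(torch_dtype: torch.dtype):
    mapping = {
        torch.float32: jnp.float32,
        torch.float16: jnp.float16,
        torch.bfloat16: jnp.bfloat16,
    }
    return mapping.get(torch_dtype, jnp.float32)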
@@ -0,0 +1,369 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from dataclasses import dataclass, fields
+
+ import jax
+ import jax.numpy as jnp
+ from jax.experimental.layout import Format, Layout, with_layout_constraint
+ from jax.sharding import Mesh, NamedSharding, PartitionSpec
+ from torchax.tensor import Tensor
+
+ from tpu_inference.layers.common.quantization import quantize_tensor
+ from tpu_inference.layers.common.sharding import ShardingAxisName
+ from tpu_inference.layers.common.utils import \
+     reorder_concatenated_tensor_for_sharding
+ from tpu_inference.layers.vllm.fused_moe import FusedMoEBackend
+ from tpu_inference.utils import align_to
+
+ P = PartitionSpec
+
+
+ @jax.tree_util.register_dataclass
+ @dataclass
+ class FusedMoEWeights:
+     """Fused moe weights. weights can be either jax or torchax array."""
+     w13_weight: jax.Array | Tensor
+     w13_weight_scale: jax.Array | Tensor | None
+     w13_bias: jax.Array | Tensor | None
+     w2_weight: jax.Array | Tensor
+     w2_weight_scale: jax.Array | Tensor | None
+     w2_bias: jax.Array | Tensor | None
+
+
+ def quantize_moe_weights(
+     weights: FusedMoEWeights,
+     dtype: jnp.dtype,
+     block_size: int | None,
+ ) -> FusedMoEWeights:
+     """Quantize fused moe weights into a given dtype and block size.
+
+     Args:
+         weights: fused moe weights.
+         dtype: dtype to perform quantization.
+         block_size: Specify block quantization size. If non, use per-channel
+             quantization. If contracting dim is not divisible by block size,
+             the dim will be automatically padded and corresponding dim on bias
+             and the other weight (w13_weight <-> w2_weight) is also padded.
+
+     Returns:
+         Quantized fused moe weights that may have also been padded.
+     """
+
+     # If scale is present, it means the weights are already quantized.
+     # Ensure that weights are not quantized by checking if scales are None.
+     assert weights.w13_weight_scale is None
+     assert weights.w2_weight_scale is None
+
+     w13_weight = weights.w13_weight
+     w2_weight = weights.w2_weight
+
+     if block_size is None:
+         # Use per-channel quantizaiton.
+         w13_block_size = w13_weight.shape[-1]
+         w2_block_size = w2_weight.shape[-1]
+     else:
+         w13_block_size = w2_block_size = block_size
+
+     _, orig_hidden_size, orig_intermediate_size = w2_weight.shape
+
+     w13_weight, w13_weight_scale = quantize_tensor(dtype, w13_weight, 2,
+                                                    w13_block_size, True)
+     w2_weight, w2_weight_scale = quantize_tensor(dtype, w2_weight, 2,
+                                                  w2_block_size, True)
+
+     intermediate_size = w2_weight.shape[-1]
+     hidden_size = w13_weight.shape[-1]
+
+     # Dims may have been padded to align with subchannel size during
+     # quantization. We pad the corresponding dim on other weight.
+     # NOTE: We perform padding after quantization as padding value can
+     # affect quantization numerics.
+     w13_pad_widths = [[0, 0] for _ in range(3)]
+     w13_pad_widths[1][1] = 2 * (intermediate_size - orig_intermediate_size)
+     w2_pad_widths = [[0, 0] for _ in range(3)]
+     w2_pad_widths[1][1] = hidden_size - orig_hidden_size
+
+     weights.w13_weight = jnp.pad(w13_weight, w13_pad_widths)
+     weights.w13_weight_scale = jnp.pad(w13_weight_scale, w13_pad_widths)
+     weights.w2_weight = jnp.pad(w2_weight, w2_pad_widths)
+     weights.w2_weight_scale = jnp.pad(w2_weight_scale, w2_pad_widths)
+
+     if (w13_bias := weights.w13_bias) is not None:
+         weights.w13_bias = jnp.pad(w13_bias, w13_pad_widths[:2])
+     if (w2_bias := weights.w2_bias) is not None:
+         weights.w2_bias = jnp.pad(w2_bias, w2_pad_widths[:2])
+
+     return weights
+
+
+ def process_moe_weights(
+     weights: FusedMoEWeights,
+     moe_backend: FusedMoEBackend,
+     w13_reorder_size: int | None = None,
+     w13_interleave: bool = False,
+ ) -> FusedMoEWeights:
+     """Process fused moe weights to a layout that moe backend expects.
+
+     Args:
+         weights: fused moe weights.
+         moe_backend: backend type the weights should be processed for.
+         w13_reorder_size: only used when backend type is GMM_TP. in order to
+             eliminate collective operations when using tensor parallelism,
+             group w13_weight into w13_reorder_size number of chuncks where each
+             chunk stores both w1 and w3 weights.
+         w13_interleave: used when loaded w13_weight is stored in interleaved
+             pattern where even index element is w1 and odd index element is w3.
+             we uninterleave so that first half is w1 and second half is w3.
+
+     Returns:
+         MoE weights that are processed for specified backend.
+     """
+
+     w13_weight = weights.w13_weight
+     w13_weight_scale = weights.w13_weight_scale
+     w13_bias = weights.w13_bias
+     w2_weight = weights.w2_weight
+     w2_weight_scale = weights.w2_weight_scale
+     w2_bias = weights.w2_bias
+
+     num_experts, hidden_size, intermediate_size = w2_weight.shape
+
+     if w13_interleave:
+         w1_weight = w13_weight[:, ::2, :]
+         w3_weight = w13_weight[:, 1::2, :]
+         w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
+
+         if w13_weight_scale is not None:
+             w1_weight_scale = w13_weight_scale[:, ::2, :]
+             w3_weight_scale = w13_weight_scale[:, 1::2, :]
+             w13_weight_scale = jnp.concat([w1_weight_scale, w3_weight_scale],
+                                           axis=1)
+
+         if w13_bias is not None:
+             w1_bias = w13_bias[:, ::2]
+             w3_bias = w13_bias[:, 1::2]
+             w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
+
+     if w13_weight_scale is not None:
+         w13_weight_scale = w13_weight_scale.astype(jnp.float32)
+     if w2_weight_scale is not None:
+         w2_weight_scale = w2_weight_scale.astype(jnp.float32)
+     if w13_bias is not None:
+         w13_bias = w13_bias.astype(jnp.float32)
+     if w2_bias is not None:
+         w2_bias = w2_bias.astype(jnp.float32)
+
+     match moe_backend:
+         case FusedMoEBackend.FUSED_MOE:
+             # Kernel expects:
+             #   w13: (num_experts, 2, hidden_size, intermediate_size)
+             #   w2: (num_experts, intermediate_size, hidden_size)
+             # Current format:
+             #   w13_weight: (num_experts, 2*intermediate_size, hidden_size)
+             #   w2_weight: (num_experts, hidden_size, intermediate_size)
+
+             # Fused moe kernel expects dims to be multiple of 256.
+             pad_width_intermediate_size = align_to(intermediate_size,
+                                                    256) - intermediate_size
+             pad_width_hidden_size = align_to(hidden_size, 256) - hidden_size
+
+             w13_weight = w13_weight.reshape(
+                 num_experts,
+                 2,
+                 intermediate_size,
+                 hidden_size,
+             )
+
+             # Transpose non-constracting dim to right most dim
+             w13_weight = jnp.swapaxes(w13_weight, 2, 3)
+             w2_weight = jnp.swapaxes(w2_weight, 1, 2)
+
+             # Workaround for JAX error "must have valid byte strides"
+             w13_weight = with_layout_constraint(w13_weight, Layout(
+                 (0, 1, 2, 3)))
+             w2_weight = with_layout_constraint(w2_weight, Layout((0, 1, 2)))
+
+             w13_weight = jnp.pad(
+                 w13_weight,
+                 ((0, 0), (0, 0), (0, pad_width_hidden_size),
+                  (0, pad_width_intermediate_size)),
+             )
+
+             w2_weight = jnp.pad(
+                 w2_weight,
+                 ((0, 0), (0, pad_width_intermediate_size),
+                  (0, pad_width_hidden_size)),
+             )
+
+             if w13_weight_scale is not None:
+                 w13_weight_scale = w13_weight_scale.reshape(
+                     num_experts, 2, intermediate_size, 1, -1)
+                 w13_weight_scale = jnp.swapaxes(w13_weight_scale, 2, 4)
+                 w13_weight_scale = jnp.pad(
+                     w13_weight_scale,
+                     ((0, 0), (0, 0), (0, pad_width_hidden_size), (0, 0),
+                      (0, pad_width_intermediate_size)),
+                 )
+             if w2_weight_scale is not None:
+                 w2_weight_scale = w2_weight_scale.reshape(
+                     num_experts, hidden_size, 1, -1)
+                 w2_weight_scale = jnp.swapaxes(w2_weight_scale, 1, 3)
+                 w2_weight_scale = jnp.pad(
+                     w2_weight_scale,
+                     ((0, 0), (0, pad_width_intermediate_size), (0, 0),
+                      (0, pad_width_hidden_size)),
+                 )
+
+             if w13_bias is not None:
+                 w13_bias = w13_bias.reshape(num_experts, 2, 1,
+                                             intermediate_size)
+                 w13_bias = jnp.pad(
+                     w13_bias,
+                     ((0, 0), (0, 0), (0, 0), (0, pad_width_intermediate_size)),
+                 )
+             if w2_bias is not None:
+                 w2_bias = w2_bias.reshape(num_experts, 1, hidden_size)
+                 w2_bias = jnp.pad(
+                     w2_bias,
+                     ((0, 0), (0, 0), (0, pad_width_hidden_size)),
+                 )
+
+         case FusedMoEBackend.GMM_EP | FusedMoEBackend.GMM_TP:
+             if w13_weight_scale is not None:
+                 w13_weight_scale = jnp.swapaxes(w13_weight_scale, 1, 2)
+                 w13_weight_scale = jnp.expand_dims(w13_weight_scale, 2)
+             if w2_weight_scale is not None:
+                 w2_weight_scale = jnp.swapaxes(w2_weight_scale, 1, 2)
+                 w2_weight_scale = jnp.expand_dims(w2_weight_scale, 2)
+             if w13_bias is not None:
+                 w13_bias = jnp.expand_dims(w13_bias, 1)
+             if w2_bias is not None:
+                 w2_bias = jnp.expand_dims(w2_bias, 1)
+
+             if moe_backend == FusedMoEBackend.GMM_TP:
+                 assert w13_reorder_size is not None
+                 assert intermediate_size % w13_reorder_size == 0
+                 output_sizes = [intermediate_size, intermediate_size]
+                 w13_weight = reorder_concatenated_tensor_for_sharding(
+                     w13_weight,
+                     output_sizes,
+                     w13_reorder_size,
+                     dim=1,
+                 )
+                 if w13_weight_scale is not None:
+                     w13_weight_scale = reorder_concatenated_tensor_for_sharding(
+                         w13_weight_scale,
+                         output_sizes,
+                         w13_reorder_size,
+                         dim=3,
+                     )
+                 if w13_bias is not None:
+                     w13_bias = reorder_concatenated_tensor_for_sharding(
+                         w13_bias,
+                         output_sizes,
+                         w13_reorder_size,
+                         dim=2,
+                     )
+
+     return FusedMoEWeights(
+         w13_weight=w13_weight,
+         w13_weight_scale=w13_weight_scale,
+         w13_bias=w13_bias,
+         w2_weight=w2_weight,
+         w2_weight_scale=w2_weight_scale,
+         w2_bias=w2_bias,
+     )
+
+
+ def shard_moe_weights(
+     weights: FusedMoEWeights,
+     moe_backend: FusedMoEBackend,
+     mesh: Mesh,
+ ) -> FusedMoEWeights:
+
+     match moe_backend:
+         case FusedMoEBackend.FUSED_MOE | FusedMoEBackend.GMM_EP:
+             ep_sharding = NamedSharding(mesh, P(ShardingAxisName.EXPERT))
+             weight_shardings = FusedMoEWeights(
+                 w13_weight=ep_sharding,
+                 w13_weight_scale=ep_sharding,
+                 w13_bias=ep_sharding,
+                 w2_weight=ep_sharding,
+                 w2_weight_scale=ep_sharding,
+                 w2_bias=ep_sharding,
+             )
+         case FusedMoEBackend.GMM_TP:
+             # When using per-channel, in_dim // block_size == 1. This means we
+             # are unable to shard w2_weight_scale along 1st dim. Therefore, we
+             # fully replicate it instead.
+             if (weights.w2_weight_scale is not None
+                     and weights.w2_weight_scale.shape[1] == 1):
+                 w2_weight_scale_p_spec = P()
+             else:
+                 w2_weight_scale_p_spec = P(None, ShardingAxisName.MLP_TENSOR)
+             weight_shardings = FusedMoEWeights(
+                 w13_weight=NamedSharding(
+                     mesh,
+                     P(None, ShardingAxisName.MLP_TENSOR, None),
+                 ),  # (num_experts, out_dim, in_dim)
+                 w13_weight_scale=NamedSharding(
+                     mesh,
+                     P(None, None, None, ShardingAxisName.MLP_TENSOR),
+                 ),  # (num_experts, in_dim // block_size, 1, out_dim)
+                 w13_bias=NamedSharding(
+                     mesh,
+                     P(None, None, ShardingAxisName.MLP_TENSOR),
+                 ),  # (num_experts, 1, out_dim)
+                 w2_weight=NamedSharding(
+                     mesh,
+                     P(None, None, ShardingAxisName.MLP_TENSOR),
+                 ),  # (num_experts, out_dim, in_dim)
+                 w2_weight_scale=NamedSharding(
+                     mesh, w2_weight_scale_p_spec
+                 ),  # (num_experts, in_dim // block_size, 1, out_dim)
+                 w2_bias=NamedSharding(
+                     mesh,
+                     P(None, None, None),
+                 ),  # (num_experts, 1, out_dim)
+             )
+
+     match moe_backend:
+         case FusedMoEBackend.FUSED_MOE:
+             weight_layouts = FusedMoEWeights(
+                 w13_weight=Layout((0, 1, 2, 3)),
+                 w13_weight_scale=Layout((0, 1, 2, 3, 4)),
+                 w13_bias=Layout((0, 1, 2, 3)),
+                 w2_weight=Layout((0, 1, 2)),
+                 w2_weight_scale=Layout((0, 1, 2, 3)),
+                 w2_bias=Layout((0, 1, 2)),
+             )
+         case FusedMoEBackend.GMM_TP | FusedMoEBackend.GMM_EP:
+             weight_layouts = FusedMoEWeights(
+                 w13_weight=Layout((0, 1, 2)),
+                 w13_weight_scale=Layout((0, 1, 2, 3)),
+                 w13_bias=Layout((0, 1, 2)),
+                 w2_weight=Layout((0, 1, 2)),
+                 w2_weight_scale=Layout((0, 1, 2, 3)),
+                 w2_bias=Layout((0, 1, 2)),
+             )
+
+     for field in fields(FusedMoEWeights):
+         key = field.name
+         if (weight := getattr(weights, key, None)) is not None:
+             layout = getattr(weight_layouts, key)
+             sharding = getattr(weight_shardings, key)
+             weight = jax.device_put(weight, Format(layout, sharding))
+             setattr(weights, key, weight)
+     return weights
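Note (not part of the diff): the new file above defines the quantize / process / shard pipeline for fused MoE weights. A minimal sketch of how these helpers might be chained for the GMM tensor-parallel backend; the shapes, the int8 dtype, the reorder size, and the call order are illustrative assumptions, and only APIs defined in this file (plus FusedMoEBackend from the sibling fused_moe module) are used. Mesh placement via shard_moe_weights is omitted because it needs a device mesh with the package's sharding axis names.

import jax.numpy as jnp

from tpu_inference.layers.vllm.fused_moe import FusedMoEBackend
from tpu_inference.layers.vllm.process_weights.fused_moe_weights import (
    FusedMoEWeights, process_moe_weights, quantize_moe_weights)

num_experts, hidden, intermediate = 8, 1024, 512

# Unquantized weights in the layout the comments above describe:
# w13: (num_experts, 2 * intermediate, hidden), w2: (num_experts, hidden, intermediate).
weights = FusedMoEWeights(
    w13_weight=jnp.zeros((num_experts, 2 * intermediate, hidden), jnp.bfloat16),
    w13_weight_scale=None,
    w13_bias=None,
    w2_weight=jnp.zeros((num_experts, hidden, intermediate), jnp.bfloat16),
    w2_weight_scale=None,
    w2_bias=None,
)

# Per-channel int8 quantization (block_size=None), then regroup the w1/w3
# chunks so the GMM_TP backend can shard them without extra collectives.
weights = quantize_moe_weights(weights, jnp.int8, block_size=None)
weights = process_moe_weights(weights, FusedMoEBackend.GMM_TP, w13_reorder_size=4)
# shard_moe_weights(weights, FusedMoEBackend.GMM_TP, mesh) would then place
# the arrays on the device mesh; mesh construction is not shown here.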
@@ -0,0 +1,174 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from dataclasses import dataclass, fields
+
+ import jax
+ import torch
+ from jax.sharding import Mesh, NamedSharding, PartitionSpec
+ from torch.nn import ParameterList
+ from torch.nn.parameter import Parameter
+ from torchax.tensor import Tensor
+
+ from tpu_inference.layers.common.utils import \
+     reorder_concatenated_tensor_for_sharding
+ from tpu_inference.logger import init_logger
+
+ P = PartitionSpec
+
+ logger = init_logger(__name__)
+
+
+ @jax.tree_util.register_dataclass
+ @dataclass
+ class LinearWeights:
+     weight: jax.Array | Tensor | list[jax.Array | Tensor]
+     weight_scale: jax.Array | Tensor | list[jax.Array | Tensor] | None
+     zero_point: jax.Array | Tensor | list[jax.Array | Tensor] | None
+     bias: jax.Array | Tensor | list[jax.Array | Tensor] | None
+
+
+ MODEL_MATMUL_FUSION_TRUTH_TABLE = {
+     ("Qwen/Qwen2.5-7B-Instruct", 1024, 1, "QKVParallelLinear"):
+     True,
+     ("Qwen/Qwen2.5-7B-Instruct", 1024, 1, "MergedColumnParallelLinear"):
+     False,
+     ("Qwen/Qwen2.5-7B-Instruct", 2048, 1, "QKVParallelLinear"):
+     False,
+     ("Qwen/Qwen2.5-7B-Instruct", 2048, 1, "MergedColumnParallelLinear"):
+     False,
+     ("meta-llama/Llama-3.1-8B-Instruct", 1024, 1, "QKVParallelLinear"):
+     False,
+     ("meta-llama/Llama-3.1-8B-Instruct", 1024, 1, "MergedColumnParallelLinear"):
+     False,
+     ("meta-llama/Llama-3.1-8B-Instruct", 2048, 1, "QKVParallelLinear"):
+     False,
+     ("meta-llama/Llama-3.1-8B-Instruct", 2048, 1, "MergedColumnParallelLinear"):
+     False,
+     ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 1024, 1, "QKVParallelLinear"):
+     False,
+     ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 1024, 1, "MergedColumnParallelLinear"):
+     False,
+     ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 2048, 1, "QKVParallelLinear"):
+     False,
+     ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 2048, 1, "MergedColumnParallelLinear"):
+     False,
+ }
+
+
+ def to_parameter_list(tensor: list[torch.Tensor]):
+     tensor = [Parameter(t, requires_grad=False) for t in tensor]
+     return ParameterList(tensor)
+
+
+ def get_model_matmul_fusion_assignment(model_name: str, batch_size: int,
+                                        tp_size: int, layer_name: str):
+     key = (model_name, batch_size, tp_size, layer_name)
+     return MODEL_MATMUL_FUSION_TRUTH_TABLE.get(key, True)
+
+
+ def process_lienar_weights(
+     weights: LinearWeights,
+     fused: bool = False,
+     output_sizes: list[int] | None = None,
+     reorder_size: int | None = None,
+     transposed: bool = True,
+     per_tensor: bool = False,
+ ) -> LinearWeights:
+     weight = weights.weight
+     weight_scale = weights.weight_scale
+     zero_point = weights.zero_point
+     bias = weights.bias
+
+     dim = 0 if transposed else -1
+     if output_sizes is None:
+         output_sizes = [weight.shape[dim]]
+
+     if fused:
+         assert reorder_size is not None
+         weight = reorder_concatenated_tensor_for_sharding(
+             weight, output_sizes, reorder_size, dim)
+
+         if weight_scale is not None and not per_tensor:
+             weight_scale = reorder_concatenated_tensor_for_sharding(
+                 weight_scale, output_sizes, reorder_size, dim)
+         if zero_point is not None:
+             zero_point = reorder_concatenated_tensor_for_sharding(
+                 zero_point, output_sizes, reorder_size, dim)
+         if bias is not None:
+             bias = reorder_concatenated_tensor_for_sharding(
+                 bias, output_sizes, reorder_size, dim)
+     else:
+
+         def slice_tensor(tensor):
+             tensors = []
+             start = 0
+             for size in output_sizes:
+                 end = start + size
+                 tensor_split = jax.lax.slice_in_dim(tensor,
+                                                     start,
+                                                     end,
+                                                     axis=dim)
+                 tensors.append(tensor_split)
+                 start = end
+             return tensors
+
+         weight = slice_tensor(weight)
+         if weight_scale is not None and not per_tensor:
+             weight_scale = slice_tensor(weight_scale)
+         if zero_point is not None:
+             zero_point = slice_tensor(zero_point)
+         if bias is not None:
+             bias = slice_tensor(bias)
+
+     return LinearWeights(
+         weight=weight,
+         weight_scale=weight_scale,
+         zero_point=zero_point,
+         bias=bias,
+     )
+
+
+ def shard_linear_weights(
+     weights: LinearWeights,
+     mesh: Mesh,
+     weight_p_spec: PartitionSpec,
+     bias_p_spec: PartitionSpec,
+     transposed: bool = True,
+     per_tensor: bool = False,
+ ) -> LinearWeights:
+
+     if not transposed:
+         # By defualt, we use transposed weights. If it is not transposed,
+         # we need to transpose the sharding as well.
+         weight_p_spec = PartitionSpec(*weight_p_spec[::-1])
+         bias_p_spec = PartitionSpec(weight_p_spec[0])
+
+     weight_sharding = NamedSharding(mesh, weight_p_spec)
+     bias_sharding = NamedSharding(mesh, bias_p_spec)
+
+     weight_shardings = LinearWeights(
+         weight=weight_sharding,
+         weight_scale=NamedSharding(mesh, P()) if per_tensor else bias_sharding,
+         zero_point=bias_sharding,
+         bias=bias_sharding,
+     )
+
+     for field in fields(LinearWeights):
+         key = field.name
+         if (weight := getattr(weights, key, None)) is not None:
+             sharding = getattr(weight_shardings, key)
+             weight = jax.device_put(weight, sharding)
+             setattr(weights, key, weight)
+     return weights
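Note (not part of the diff): a rough usage sketch of the new linear-weight helpers above, showing the matmul-fusion truth table feeding process_lienar_weights (spelled as released) for a transposed, concatenated QKV weight. The sizes, model name, and reorder size are placeholders for illustration, and only functions defined in this file are called.

import jax.numpy as jnp

from tpu_inference.layers.vllm.process_weights.linear_weights import (
    LinearWeights, get_model_matmul_fusion_assignment, process_lienar_weights)

# Transposed layout: (output_size, input_size) with Q, K and V concatenated on dim 0.
q, kv, hidden = 4096, 1024, 4096
weights = LinearWeights(
    weight=jnp.zeros((q + 2 * kv, hidden), jnp.bfloat16),
    weight_scale=None,
    zero_point=None,
    bias=None,
)

fused = get_model_matmul_fusion_assignment(
    "Qwen/Qwen2.5-7B-Instruct", batch_size=1024, tp_size=1,
    layer_name="QKVParallelLinear")  # True per the table above

weights = process_lienar_weights(
    weights,
    fused=fused,
    output_sizes=[q, kv, kv],
    reorder_size=4 if fused else None,
    transposed=True,
)
# With fused=False the weight comes back as a list of three per-projection
# slices; with fused=True it stays a single array, reordered for sharding.
# shard_linear_weights would then place it on a mesh with a chosen PartitionSpec.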
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import copy

  from jax.sharding import Mesh
@@ -7,9 +21,10 @@ from vllm.model_executor.layers.quantization.base_config import \

  from tpu_inference.layers.common import quant_methods
  from tpu_inference.layers.vllm.quantization.awq import VllmAWQConfig
- from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
  from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
-     VllmCompressedTensorsConfig  # noqa: E501
+     VllmCompressedTensorsConfig
+ from tpu_inference.layers.vllm.quantization.configs import VllmQuantConfig
+ from tpu_inference.layers.vllm.quantization.fp8 import VllmFp8Config
  from tpu_inference.layers.vllm.quantization.mxfp4 import VllmMxfp4Config
  from tpu_inference.layers.vllm.quantization.unquantized import \
      VllmUnquantizedConfig
@@ -23,6 +38,7 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
          None: VllmUnquantizedConfig,
          quant_methods.COMPRESSED_TENSORS: VllmCompressedTensorsConfig,
          quant_methods.AWQ: VllmAWQConfig,
+         quant_methods.FP8: VllmFp8Config,
          quant_methods.MXFP4: VllmMxfp4Config,
      }
      if model_config.quantization not in method_to_config:
@@ -30,7 +46,7 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
          f"{model_config.quantization} quantization method not supported."
          f" Supported methods are {method_to_config.keys()}")
      quant_config = method_to_config[model_config.quantization]
-     assert issubclass(quant_config, JaxCommonConfig)
+     assert issubclass(quant_config, VllmQuantConfig)
      quant_config.set_configs(vllm_config, mesh)

      model_config.quantization = quant_methods.get_tpu_quant_method(