tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.
Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0

tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
@@ -1,21 +1,41 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Optional
 
 import jax
 import jax.numpy as jnp
 import torch
 from compressed_tensors.quantization import QuantizationStrategy
-from jax.sharding import NamedSharding, PartitionSpec
+from jax.sharding import PartitionSpec
+from torch.nn.parameter import Parameter
 from torchax.interop import jax_view, torch_view
-from vllm.logger import init_logger
+from torchax.ops.mappings import t2j
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
     CompressedTensorsW8A8Int8
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import \
     convert_to_channelwise
 
-from tpu_inference.layers.vllm.linear_common import (
-    sharded_quantized_matmul, slice_sharded_tensor_for_concatenation,
-    torch_to_jax_param)
-from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig
+from tpu_inference.layers.common.utils import \
+    slice_sharded_tensor_for_concatenation
+from tpu_inference.layers.vllm.linear import sharded_quantized_matmul
+from tpu_inference.layers.vllm.process_weights.linear_weights import (
+    LinearWeights, process_lienar_weights, shard_linear_weights,
+    to_parameter_list)
+from tpu_inference.layers.vllm.quantization.configs import \
+    VllmQuantLinearConfig
+from tpu_inference.logger import init_logger
 
 P = PartitionSpec
 logger = init_logger(__name__)
@@ -24,23 +44,15 @@ logger = init_logger(__name__)
 class VllmCompressedTensorsW8A8Int8(CompressedTensorsW8A8Int8):
 
     def __init__(self, strategy: str, is_static_input_scheme: bool,
-                 input_symmetric: bool, jax_config: JaxCommonLinearConfig):
+                 input_symmetric: bool, linear_config: VllmQuantLinearConfig):
         super().__init__(strategy, is_static_input_scheme, input_symmetric)
 
-        self.jax_config = jax_config
-        self.is_channelwise = (self.strategy == QuantizationStrategy.CHANNEL),
+        self.linear_config = linear_config
+        self.is_channelwise = (self.strategy == QuantizationStrategy.CHANNEL)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        weight = torch_to_jax_param(
-            layer.weight,
-            NamedSharding(self.jax_config.mesh,
-                          self.jax_config.weight_sharding),
-            self.jax_config.output_sizes,
-            self.jax_config.n_shards,
-            self.jax_config.fuse_matmuls,
-        )
+        weight = t2j(layer.weight, use_dlpack=False)
         delattr(layer, "weight")
-        layer.weight = weight
 
         weight_scale = layer.weight_scale
         is_fused_module = len(layer.logical_widths) > 1
@@ -48,31 +60,55 @@ class VllmCompressedTensorsW8A8Int8(CompressedTensorsW8A8Int8):
             weight_scale = convert_to_channelwise(weight_scale,
                                                   layer.logical_widths)
             weight_scale = weight_scale.squeeze(-1)
-
-        weight_scale = torch_to_jax_param(
-            weight_scale,
-            NamedSharding(self.jax_config.mesh, self.jax_config.bias_sharding),
-            self.jax_config.output_sizes,
-            self.jax_config.n_shards,
-            self.jax_config.fuse_matmuls,
-        )
+        weight_scale = t2j(weight_scale, use_dlpack=False)
         delattr(layer, "weight_scale")
-        layer.weight_scale = weight_scale
 
         if layer.bias is not None and not layer.skip_bias_add:
             if layer.return_bias:
                 logger.warning_once("Bias might return incorrect value.")
-
-            bias = torch_to_jax_param(
-                layer.bias,
-                NamedSharding(self.jax_config.mesh,
-                              self.jax_config.bias_sharding),
-                self.jax_config.output_sizes,
-                self.jax_config.n_shards,
-                self.jax_config.fuse_matmuls,
-            )
+            bias = t2j(layer.bias, use_dlpack=False)
             delattr(layer, "bias")
-            layer.bias = bias
+        else:
+            bias = None
+
+        @jax.jit
+        def process_int8_linear_weights(
+            weight: jax.Array,
+            weight_scale: jax.Array,
+            bias: jax.Array | None,
+        ) -> LinearWeights:
+            return process_lienar_weights(
+                LinearWeights(
+                    weight=weight,
+                    weight_scale=weight_scale,
+                    zero_point=None,
+                    bias=bias,
+                ),
+                fused=self.linear_config.fuse_matmuls,
+                output_sizes=self.linear_config.output_sizes,
+                reorder_size=self.linear_config.n_shards,
+            )
+
+        weights = process_int8_linear_weights(weight, weight_scale, bias)
+        weights = torch_view(
+            shard_linear_weights(
+                weights,
+                mesh=self.linear_config.mesh,
+                weight_p_spec=self.linear_config.weight_sharding,
+                bias_p_spec=self.linear_config.bias_sharding,
+            ))
+
+        if self.linear_config.fuse_matmuls:
+            layer.weight = Parameter(weights.weight, requires_grad=False)
+            layer.weight_scale = Parameter(weights.weight_scale,
+                                           requires_grad=False)
+            if bias is not None:
+                layer.bias = Parameter(weights.bias, requires_grad=False)
+        else:
+            layer.weight = to_parameter_list(weights.weight)
+            layer.weight_scale = to_parameter_list(weights.weight_scale)
+            if bias is not None:
+                layer.bias = to_parameter_list(weights.bias)
 
         # TODO(kyuyeunk): Support static range input quantization.
         assert getattr(layer, "input_scale", None) is None
@@ -82,7 +118,7 @@ class VllmCompressedTensorsW8A8Int8(CompressedTensorsW8A8Int8):
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
         with jax.named_scope(layer._get_name()):
-            if self.jax_config.fuse_matmuls:
+            if self.linear_config.fuse_matmuls:
                 out = self._apply_fused(layer, x, bias)
             else:
                 out = self._apply_split(layer, x, bias)
@@ -99,14 +135,14 @@ class VllmCompressedTensorsW8A8Int8(CompressedTensorsW8A8Int8):
             x_jax,
             weight_jax,
            weight_scale_jax,
-            self.jax_config.mesh,
-            self.jax_config.weight_sharding,
+            self.linear_config.mesh,
+            self.linear_config.weight_sharding,
         )
         if bias is not None and not layer.skip_bias_add:
             outs += jax_view(bias)
 
         outs = slice_sharded_tensor_for_concatenation(
-            outs, self.jax_config.output_sizes, self.jax_config.n_shards)
+            outs, self.linear_config.output_sizes, self.linear_config.n_shards)
         out = jnp.concatenate(outs, axis=-1)
         return torch_view(out)
 
@@ -125,8 +161,8 @@ class VllmCompressedTensorsW8A8Int8(CompressedTensorsW8A8Int8):
                 x_jax,
                 weight_jax,
                 weight_scale_jax,
-                self.jax_config.mesh,
-                self.jax_config.weight_sharding,
+                self.linear_config.mesh,
+                self.linear_config.weight_sharding,
             )
             if bias is not None and not layer.skip_bias_add:
                 out += jax_view(bias[i])
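
For orientation, the reworked int8 scheme above follows one pattern throughout: pull the torch tensors into JAX with t2j, do the one-time layout work (channelwise conversion, reordering for fused outputs) under jax.jit, shard the results onto the mesh, and hand them back to the torch layer as frozen Parameters through torch_view. The following is a minimal, self-contained sketch of that pattern only, not the package's helpers; the mesh, the "model" axis name, and the squeeze step are illustrative assumptions.

import jax
import jax.numpy as jnp
import torch
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from torch.nn.parameter import Parameter
from torchax.interop import torch_view
from torchax.ops.mappings import t2j


def int8_weight_pipeline_sketch(weight_torch: torch.Tensor,
                                scale_torch: torch.Tensor,
                                mesh: Mesh) -> tuple[Parameter, Parameter]:
    # 1. Bring the torch tensors into JAX (no dlpack, as in the diff above).
    weight = t2j(weight_torch, use_dlpack=False)  # int8, shape (out, in)
    scale = t2j(scale_torch, use_dlpack=False)    # f32, shape (out, 1)

    # 2. One-time layout rework under jit; here just drop the trailing scale
    #    dimension, mirroring the channelwise squeeze in the scheme above.
    @jax.jit
    def rework(w, s):
        return w, jnp.squeeze(s, axis=-1)

    weight, scale = rework(weight, scale)

    # 3. Shard the weight along an assumed "model" mesh axis; replicate the scale.
    weight = jax.device_put(weight,
                            NamedSharding(mesh, PartitionSpec("model", None)))
    scale = jax.device_put(scale, NamedSharding(mesh, PartitionSpec()))

    # 4. Re-expose the JAX arrays to the torch layer as frozen Parameters.
    return (Parameter(torch_view(weight), requires_grad=False),
            Parameter(torch_view(scale), requires_grad=False))
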
tpu_inference/layers/vllm/quantization/{common.py → configs.py}
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import torchax
 from jax.sharding import Mesh, PartitionSpec
 from vllm.config import VllmConfig
@@ -11,9 +25,10 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
 
-from tpu_inference.layers.vllm.linear_common import \
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.vllm.process_weights.linear_weights import \
     get_model_matmul_fusion_assignment
-from tpu_inference.utils import TPU_SECOND_LAST_MINOR
+from tpu_inference.utils import TPU_SECOND_LAST_MINOR, get_mesh_shape_product
 
 # yapf: enable
 
@@ -22,7 +37,7 @@ P = PartitionSpec
 logger = init_logger(__name__)
 
 
-class JaxCommonLinearConfig:
+class VllmQuantLinearConfig:
 
     def __init__(self, vllm_config: VllmConfig, mesh: Mesh, layer: LinearBase):
         assert isinstance(layer, LinearBase)
@@ -31,18 +46,22 @@ class JaxCommonLinearConfig:
         self.output_sizes = [layer.output_size]
         self.weight_sharding = P(None, None)
         self.fuse_matmuls = True
-        self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
+        self.enable_sp = vllm_config.compilation_config.pass_config.enable_sp
        self.input_sharding = None
         self.output_sharding = None
 
+        self.tp_size = get_mesh_shape_product(self.mesh,
+                                              ShardingAxisName.MLP_TENSOR)
+
         if isinstance(layer, RowParallelLinear):
-            self.weight_sharding = P(None, "model")
-            if self.enable_sequence_parallelism:
-                self.output_sharding = P("model", None)
+            self.weight_sharding = P(None, ShardingAxisName.ATTN_HEAD)
+            if self.enable_sp:
+                self.output_sharding = P(ShardingAxisName.MLP_TENSOR, None)
         elif isinstance(layer, ColumnParallelLinear):
-            self.weight_sharding = P("model", None)
-            if self.enable_sequence_parallelism:
-                self.input_sharding = P("model", None)
+            self.weight_sharding = P(ShardingAxisName.ATTN_HEAD, None)
+
+            if self.enable_sp:
+                self.input_sharding = P(ShardingAxisName.MLP_TENSOR, None)
 
         if isinstance(layer, MergedColumnParallelLinear) or isinstance(
                 layer, QKVParallelLinear):
@@ -61,30 +80,28 @@ class JaxCommonLinearConfig:
                            " bad performance.", type(layer))
 
         self.bias_sharding = P(self.weight_sharding[0])
-        self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
+        self.n_shards = get_mesh_shape_product(self.mesh,
+                                               self.weight_sharding[0])
 
     def get_input_sharding(self, x: torchax.tensor.Tensor):
-        if self.enable_sequence_parallelism:
-            token_num = x.shape[0]
-            # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
-            if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
-                return self.input_sharding
-            else:
-                return None
+        if not self.enable_sp:
+            return None
+        token_num = x.shape[0]
+        # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
+        if token_num // self.tp_size < TPU_SECOND_LAST_MINOR:
+            return None
         return self.input_sharding
 
     def get_output_sharding(self, x: torchax.tensor.Tensor):
-        if self.enable_sequence_parallelism:
+        if self.enable_sp:
             token_num = x.shape[0]
             # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
-            if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
-                return self.output_sharding
-            else:
+            if token_num // self.tp_size < TPU_SECOND_LAST_MINOR:
                 return None
         return self.output_sharding
 
 
-class JaxCommonConfig:
+class VllmQuantConfig:
     vllm_config: VllmConfig
     mesh: Mesh
 
@@ -93,9 +110,9 @@ class JaxCommonConfig:
         cls.vllm_config = vllm_config
         cls.mesh = mesh
 
-    def get_linear_config(self, layer: LinearBase) -> JaxCommonLinearConfig:
+    def get_linear_config(self, layer: LinearBase) -> VllmQuantLinearConfig:
         assert isinstance(layer, LinearBase)
-        return JaxCommonLinearConfig(self.vllm_config, self.mesh, layer)
+        return VllmQuantLinearConfig(self.vllm_config, self.mesh, layer)
 
     def get_moe_config(self, layer: FusedMoE) -> FusedMoEConfig:
         assert isinstance(layer, FusedMoE)
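
The enable_sp handling in get_input_sharding/get_output_sharding above boils down to a small gate: shard activations along the token axis only when every shard keeps at least TPU_SECOND_LAST_MINOR rows, otherwise skip the resharding for that matmul. A standalone sketch of that check follows; the threshold value of 8 and the "model" axis name are assumptions for illustration, not the package's constants.

from jax.sharding import PartitionSpec

TPU_SECOND_LAST_MINOR = 8  # assumed value, for illustration only


def maybe_token_sharding(token_num: int, tp_size: int) -> PartitionSpec | None:
    # Too few tokens per shard: keep the activation unsharded and skip
    # sequence parallelism for this matmul.
    if token_num // tp_size < TPU_SECOND_LAST_MINOR:
        return None
    # Enough tokens: shard the token (first) dimension across the mesh axis.
    return PartitionSpec("model", None)


# With tp_size=4: a 16-token batch stays unsharded (16 // 4 = 4 < 8),
# while a 64-token batch is sharded (64 // 4 = 16 >= 8).
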
tpu_inference/layers/vllm/quantization/fp8.py (new file)
@@ -0,0 +1,119 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import jax
+import torch
+from jax.sharding import PartitionSpec
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import \
+    register_quantization_config
+from vllm.model_executor.layers.quantization.base_config import \
+    QuantizeMethodBase
+from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
+                                                         Fp8LinearMethod)
+from vllm.model_executor.layers.quantization.utils.quant_utils import \
+    is_layer_skipped
+
+from tpu_inference.layers.common.quant_methods import FP8, get_tpu_quant_method
+from tpu_inference.layers.vllm.quantization.configs import (
+    VllmQuantConfig, VllmQuantLinearConfig)
+from tpu_inference.layers.vllm.quantization.unquantized import \
+    VllmUnquantizedLinearMethod
+from tpu_inference.logger import init_logger
+
+P = PartitionSpec
+
+logger = init_logger(__name__)
+
+
+@register_quantization_config(get_tpu_quant_method(FP8))
+class VllmFp8Config(Fp8Config, VllmQuantConfig):
+
+    @classmethod
+    def get_name(cls):
+        return FP8
+
+    def get_supported_act_dtypes(self) -> list[torch.dtype]:
+        return [torch.bfloat16]
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
+        if isinstance(layer, LinearBase):
+            linear_config = self.get_linear_config(layer)
+            if is_layer_skipped(prefix, self.ignored_layers):
+                return VllmUnquantizedLinearMethod(linear_config)
+            return VllmFp8LinearMethod(self, linear_config)
+        elif isinstance(layer, FusedMoE):
+            raise NotImplementedError(
+                "FP8 FusedMoE is currently not supported in torchax-jax")
+        return None
+
+
+class VllmFp8LinearMethod(Fp8LinearMethod):
+
+    def __init__(self, quant_config: VllmFp8Config,
+                 jax_config: VllmQuantLinearConfig):
+        super().__init__(quant_config)
+        self.jax_config = jax_config
+        self._configure_sharding()
+
+    def _configure_sharding(self) -> None:
+
+        raise NotImplementedError(
+            "Configure PartitionSpec for weight_sharding and scale_sharding "
+            "based on layer type (RowParallel/ColumnParallel)")
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+
+        raise NotImplementedError(
+            "Convert layer.weight, layer.weight_scale, and optionally "
+            "layer.input_scale and layer.bias from torch tensors to JAX arrays "
+            "using torch_to_jax_param() with appropriate sharding")
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+        with jax.named_scope(layer._get_name()):
+            if self.jax_config.fuse_matmuls:
+                out = self._apply_fused(layer, x, bias)
+            else:
+                out = self._apply_split(layer, x, bias)
+
+        return out
+
+    def _apply_fused(self,
+                     layer: torch.nn.Module,
+                     x: torch.Tensor,
+                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+        raise NotImplementedError(
+            "Implement single matmul for fused outputs: "
+            "quantize input to fp8, perform fp8 matmul with weight and scales, "
+            "dequantize output, and add bias if present")
+
+    def _apply_split(self,
+                     layer: torch.nn.Module,
+                     x: torch.Tensor,
+                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+        raise NotImplementedError(
+            "Implement separate matmuls per output partition: "
+            "split weight/scale by output_sizes, perform fp8 matmul for each, "
+            "concatenate results, and add bias if present")