tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.2rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (250)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +89 -26
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +46 -17
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +44 -17
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/METADATA +7 -9
  240. tpu_inference-0.13.2rc3.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/top_level.txt +0 -0
@@ -1,38 +1,73 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Optional, Union
 
 import jax
 import jax.numpy as jnp
 import torch
-from jax.experimental.shard_map import shard_map
 from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from torchax.interop import torch_view
 from torchax.ops.mappings import t2j
 
-from tpu_inference.kernels.quantized_matmul.kernel import \
-    quantized_matmul_kernel
+from tpu_inference import envs
+from tpu_inference.kernels.quantized_matmul.kernel import (
+    quantized_matmul_kernel, xla_quantized_matmul)
 
 
 def sharded_quantized_matmul(x: jax.Array, w_q: jax.Array, w_s: jax.Array,
-                             mesh: Mesh, weight_sharding: P):
-    out_axis, in_axis = weight_sharding
-    x_sharding = P(None, in_axis)
-    scale_sharding = P(out_axis, )
-    out_sharding = P(None, out_axis)
-
-    x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, x_sharding))
-
-    def wrapper(x, w_q, w_s):
-        output = quantized_matmul_kernel(x, w_q, w_s, x_q_dtype=w_q.dtype)
-        if in_axis:
-            output = jax.lax.psum(output, axis_name=in_axis)
-        return output
-
-    return shard_map(wrapper,
-                     mesh=mesh,
-                     in_specs=(x_sharding, weight_sharding, scale_sharding),
-                     out_specs=(out_sharding),
-                     check_rep=False)(x, w_q, w_s)
+                             mesh: Mesh, weight_sharding: P) -> jax.Array:
+    """
+    Wrapper around the quantized matmul kernel.
+
+    Args:
+        x: Activation.
+        w_q: Quantized weight array. [n_output_features, n_input_features]
+        w_s: Weight quantization scale. [n_output_features]
+        mesh: Mesh to shard on.
+        weight_sharding: PartitionSpec for the weight tensor.
+
+    Returns:
+        Output of the quantized matmul.
+    """
+
+    # NOTE(jacobplatin/kyuyeunk): there have been (concerning) numeric issues
+    # with NaNs in the kernel, so we disable it for now.
+    if envs.ENABLE_QUANTIZED_MATMUL_KERNEL:
+        out_axis, in_axis = weight_sharding
+        x_sharding = P(None, in_axis)
+        scale_sharding = P(out_axis, )
+        out_sharding = P(None, out_axis)
+
+        x = jax.lax.with_sharding_constraint(x,
+                                             NamedSharding(mesh, x_sharding))
+
+        def wrapper(x, w_q, w_s):
+            output = quantized_matmul_kernel(x, w_q, w_s, x_q_dtype=w_q.dtype)
+            if in_axis:
+                output = jax.lax.psum(output, axis_name=in_axis)
+            return output
+
+        return jax.shard_map(wrapper,
+                             mesh=mesh,
+                             in_specs=(x_sharding, weight_sharding,
+                                       scale_sharding),
+                             out_specs=(out_sharding),
+                             check_vma=False)(x, w_q, w_s)
+    else:
+        return xla_quantized_matmul(x, w_q, w_s)
 
 
 def reorder_concatenated_tensor_for_sharding(concatenated_tensor: jax.Array,
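
The new else branch falls back to a plain-XLA path whenever the Pallas kernel is disabled via ENABLE_QUANTIZED_MATMUL_KERNEL. A minimal sketch of what such a fallback computes, assuming per-output-channel weight scales; the real xla_quantized_matmul lives in tpu_inference/kernels/quantized_matmul/kernel.py and is not shown in this diff, so the exact signature and rescaling order here are assumptions:

import jax
import jax.numpy as jnp


def xla_quantized_matmul_sketch(x: jax.Array, w_q: jax.Array,
                                w_s: jax.Array) -> jax.Array:
    # x:   [n_tokens, n_input_features] activations.
    # w_q: [n_output_features, n_input_features] quantized weights.
    # w_s: [n_output_features] per-output-channel scales.
    # Widen both operands to float32 and let XLA fuse the cast into the
    # matmul; no custom kernel is involved on this path.
    acc = jax.lax.dot_general(
        x.astype(jnp.float32), w_q.astype(jnp.float32),
        dimension_numbers=(((1,), (1,)), ((), ())))
    # Undo the per-channel weight quantization and return in x's dtype.
    return (acc * w_s[None, :].astype(jnp.float32)).astype(x.dtype)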
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import copy
 
 from jax.sharding import Mesh
@@ -10,6 +24,7 @@ from tpu_inference.layers.vllm.quantization.awq import VllmAWQConfig
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
     VllmCompressedTensorsConfig  # noqa: E501
+from tpu_inference.layers.vllm.quantization.fp8 import VllmFp8Config
 from tpu_inference.layers.vllm.quantization.mxfp4 import VllmMxfp4Config
 from tpu_inference.layers.vllm.quantization.unquantized import \
     VllmUnquantizedConfig
@@ -23,6 +38,7 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
         None: VllmUnquantizedConfig,
         quant_methods.COMPRESSED_TENSORS: VllmCompressedTensorsConfig,
         quant_methods.AWQ: VllmAWQConfig,
+        quant_methods.FP8: VllmFp8Config,
         quant_methods.MXFP4: VllmMxfp4Config,
     }
     if model_config.quantization not in method_to_config:
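
This hunk registers FP8 checkpoints in the same string-keyed dispatch table as the other quantization methods, and the trailing context line shows the lookup that guards it. A minimal sketch of that dispatch pattern; everything here except VllmFp8Config and the method_to_config shape above is illustrative:

def resolve_quant_config_cls(quantization, method_to_config):
    # `quantization` comes from the model config (e.g. "fp8" for an FP8
    # checkpoint); None selects the unquantized path.
    if quantization not in method_to_config:
        raise NotImplementedError(
            f"Quantization method {quantization!r} is not supported on TPU.")
    return method_to_config[quantization]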
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Optional, Union
 
 import jax
@@ -39,7 +53,7 @@ class VllmAWQConfig(AWQConfig, JaxCommonConfig):
 
     def get_supported_act_dtypes(self) -> list[torch.dtype]:
         # NOTE: AWQ checkpoint was quantized with float16. But on TPUs, using
-        # bfloat16 is signifcantly preferred over foat16. This might lead to
+        # bfloat16 is significantly preferred over float16. This might lead to
         # some numeric output change.
         return [torch.bfloat16]
 
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import torchax
 from jax.sharding import Mesh, PartitionSpec
 from vllm.config import VllmConfig
@@ -11,9 +25,10 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
 
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.vllm.linear_common import \
     get_model_matmul_fusion_assignment
-from tpu_inference.utils import TPU_SECOND_LAST_MINOR
+from tpu_inference.utils import TPU_SECOND_LAST_MINOR, get_mesh_shape_product
 
 # yapf: enable
 
@@ -31,18 +46,22 @@ class JaxCommonLinearConfig:
         self.output_sizes = [layer.output_size]
         self.weight_sharding = P(None, None)
         self.fuse_matmuls = True
-        self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
+        self.enable_sp = vllm_config.compilation_config.pass_config.enable_sp
         self.input_sharding = None
         self.output_sharding = None
 
+        self.tp_size = get_mesh_shape_product(self.mesh,
+                                              ShardingAxisName.MLP_TENSOR)
+
         if isinstance(layer, RowParallelLinear):
-            self.weight_sharding = P(None, "model")
-            if self.enable_sequence_parallelism:
-                self.output_sharding = P("model", None)
+            self.weight_sharding = P(None, ShardingAxisName.ATTN_HEAD)
+            if self.enable_sp:
+                self.output_sharding = P(ShardingAxisName.MLP_TENSOR, None)
         elif isinstance(layer, ColumnParallelLinear):
-            self.weight_sharding = P("model", None)
-            if self.enable_sequence_parallelism:
-                self.input_sharding = P("model", None)
+            self.weight_sharding = P(ShardingAxisName.ATTN_HEAD, None)
+
+            if self.enable_sp:
+                self.input_sharding = P(ShardingAxisName.MLP_TENSOR, None)
 
         if isinstance(layer, MergedColumnParallelLinear) or isinstance(
                 layer, QKVParallelLinear):
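
For readers unfamiliar with the two layouts this hunk configures: a RowParallelLinear weight of shape [out, in] is split along the input dimension, so each shard computes a partial matmul that must later be reduced, while a ColumnParallelLinear weight is split along the output dimension and each shard owns a disjoint slice of the output features. A toy illustration with a made-up one-axis mesh; the concrete axes behind ShardingAxisName.ATTN_HEAD and ShardingAxisName.MLP_TENSOR are implementation details not shown in this diff:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

# "model" stands in for whatever mesh axis ShardingAxisName resolves to.
mesh = Mesh(np.array(jax.devices()), ("model",))

row_parallel_spec = P(None, "model")     # split [out, in] along `in`
column_parallel_spec = P("model", None)  # split [out, in] along `out`

# Size the toy weight so both dims divide evenly across the devices.
n = jax.device_count()
w = jax.random.normal(jax.random.key(0), (4 * n, 4 * n))
w_row = jax.device_put(w, NamedSharding(mesh, row_parallel_spec))
w_col = jax.device_put(w, NamedSharding(mesh, column_parallel_spec))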
@@ -61,28 +80,24 @@
                        " bad performance.", type(layer))
 
         self.bias_sharding = P(self.weight_sharding[0])
-        if isinstance(self.weight_sharding[0], tuple):
-            self.n_shards = 1
-            for axis in self.weight_sharding[0]:
-                self.n_shards *= self.mesh.shape.get(axis, 1)
-        else:
-            self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
+        self.n_shards = get_mesh_shape_product(self.mesh,
+                                               self.weight_sharding[0])
 
     def get_input_sharding(self, x: torchax.tensor.Tensor):
-        if self.enable_sequence_parallelism:
+        if self.enable_sp:
             token_num = x.shape[0]
             # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
-            if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
+            if token_num // self.tp_size >= TPU_SECOND_LAST_MINOR:
                 return self.input_sharding
             else:
                 return None
         return self.input_sharding
 
     def get_output_sharding(self, x: torchax.tensor.Tensor):
-        if self.enable_sequence_parallelism:
+        if self.enable_sp:
             token_num = x.shape[0]
             # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
-            if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
+            if token_num // self.tp_size >= TPU_SECOND_LAST_MINOR:
                 return self.output_sharding
             else:
                 return None
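
The refactor replaces the inline tuple-vs-scalar axis handling with get_mesh_shape_product from tpu_inference.utils. The removed lines pin down its semantics, so a faithful sketch is straightforward; the actual implementation lives in tpu_inference/utils.py, outside this excerpt:

from jax.sharding import Mesh


def get_mesh_shape_product_sketch(mesh: Mesh, axes) -> int:
    # `axes` is one PartitionSpec entry: a single axis name, a tuple of
    # axis names, or None. Axes absent from the mesh contribute a factor
    # of 1, matching the removed `mesh.shape.get(axis, 1)` calls.
    if axes is None:
        return 1
    if not isinstance(axes, tuple):
        axes = (axes,)
    n_shards = 1
    for axis in axes:
        n_shards *= mesh.shape.get(axis, 1)
    return n_shards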
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Optional
 
 import torch
@@ -20,7 +34,7 @@ from tpu_inference.layers.common.quant_methods import (COMPRESSED_TENSORS,
                                                        get_tpu_quant_method)
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
-    VllmCompressedTensorsW8A8Fp8MoEMethod
+    VllmCompressedTensorsMoEMethod
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
     VllmCompressedTensorsW8A8Fp8
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
@@ -113,8 +127,9 @@ class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
             layer.scheme = scheme
             return CompressedTensorsLinearMethod(self)
         if isinstance(layer, FusedMoE):
-            return VllmCompressedTensorsW8A8Fp8MoEMethod(
-                self, layer.quant_config, self.mesh)
+            layer.moe_config = self.get_moe_config(layer)
+            return VllmCompressedTensorsMoEMethod.get_moe_method(
+                self, layer, layer_name=prefix)
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
         return None
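
The FusedMoE branch now routes through a get_moe_method factory instead of unconditionally constructing the W8A8-FP8 method, which lets the config pick a MoE quant method per layer. A self-contained sketch of that factory pattern; all names and the scheme lookup below are illustrative assumptions, not the real compressed_tensors_moe API:

# Illustrative factory: dispatch to a MoE quant method by layer scheme.
class MoEMethodFactory:
    _registry: dict = {}

    @classmethod
    def register(cls, scheme: str):
        def deco(method_cls):
            cls._registry[scheme] = method_cls
            return method_cls
        return deco

    @classmethod
    def get_moe_method(cls, config, layer, layer_name: str):
        # Hypothetical: derive this layer's scheme from the config.
        scheme = config.scheme_for(layer_name)
        if scheme not in cls._registry:
            raise NotImplementedError(
                f"No MoE method for layer {layer_name!r} ({scheme!r})")
        return cls._registry[scheme](config, layer)


@MoEMethodFactory.register("w8a8_fp8")
class W8A8Fp8MoEMethod:
    def __init__(self, config, layer) -> None:
        self.config, self.layer = config, layer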