tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,221 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional, Union
+
+ import jax
+ import jax.numpy as jnp
+ import torch
+ from jax.sharding import Mesh, NamedSharding
+ from jax.sharding import PartitionSpec as P
+ from torchax.interop import torch_view
+ from torchax.ops.mappings import t2j
+
+ from tpu_inference import envs
+ from tpu_inference.kernels.quantized_matmul.kernel import (
+     quantized_matmul_kernel, xla_quantized_matmul)
+
+
+ def sharded_quantized_matmul(x: jax.Array, w_q: jax.Array, w_s: jax.Array,
+                              mesh: Mesh, weight_sharding: P) -> jax.Array:
+     """
+     Wrapper around the quantized matmul kernel.
+
+     Args:
+         x: Activation.
+         w_q: Quantized weight array. [n_output_features, n_input_features]
+         w_s: Weight quantization scale. [n_output_features]
+         mesh: Mesh to shard on.
+         weight_sharding: PartitionSpec for the weight tensor.
+
+     Returns:
+         Output of the quantized matmul.
+     """
+
+     # NOTE(jacobplatin/kyuyeunk): there have been numeric issues (concerning
+     # NaNs) with the kernel, and thus we disable it for now.
+     if envs.ENABLE_QUANTIZED_MATMUL_KERNEL:
+         out_axis, in_axis = weight_sharding
+         x_sharding = P(None, in_axis)
+         scale_sharding = P(out_axis, )
+         out_sharding = P(None, out_axis)
+
+         x = jax.lax.with_sharding_constraint(x,
+                                              NamedSharding(mesh, x_sharding))
+
+         def wrapper(x, w_q, w_s):
+             output = quantized_matmul_kernel(x, w_q, w_s, x_q_dtype=w_q.dtype)
+             if in_axis:
+                 output = jax.lax.psum(output, axis_name=in_axis)
+             return output
+
+         return jax.shard_map(wrapper,
+                              mesh=mesh,
+                              in_specs=(x_sharding, weight_sharding,
+                                        scale_sharding),
+                              out_specs=(out_sharding),
+                              check_vma=False)(x, w_q, w_s)
+     else:
+         return xla_quantized_matmul(x, w_q, w_s)
+
+
+ def reorder_concatenated_tensor_for_sharding(concatenated_tensor: jax.Array,
+                                              split_sizes: list[int],
+                                              n_shards: int, dim: int):
+     """
+     Reorder a replicated concatenated tensor such that, when sharded on multiple chips, each shard is a concatenation of the shards of the individual tensors.
+
+     For example, let the concatenated_tensor be:
+         AAAAAAAAAAAABBBBBBBBCCCC
+         (12 As, 8 Bs, 4 Cs)
+     and let the split_sizes = [12, 8, 4] and n_shards = 4.
+     The output is:
+         AAABBCAAABBCAAABBCAAABBC
+     In other words, it reorders the input tensor into 4 segments, with each segment corresponding to one shard and being AAABBC.
+
+     Args:
+         concatenated_tensor: the tensor, concatenated on the dimension specified by `dim`.
+         split_sizes: each individual tensor's size on the dimension specified by `dim`.
+         n_shards: number of shards.
+         dim: the dimension on which the concatenated_tensor is concatenated.
+     """
+     # Split the concatenated tensor into individual tensors.
+     split_tensors = []
+     start_offset = 0
+     old_shape = concatenated_tensor.shape
+     # The new shape ensures each split_tensor[i] maps to a tensor in the ith shard.
+     new_shape = old_shape[:dim] + (n_shards, -1) + old_shape[dim + 1:]
+     for split_size in split_sizes:
+         split_tensor = jax.lax.slice_in_dim(concatenated_tensor,
+                                             start_offset,
+                                             start_offset + split_size,
+                                             axis=dim)
+         split_tensors.append(split_tensor.reshape(new_shape))
+         start_offset += split_size
+     # While keeping the 0th dim as the shard dim, concatenate along the 1st dim
+     # to create a concatenated tensor whose 0th dim maps to the shard dim.
+     reordered_tensor = jnp.concatenate(split_tensors, axis=dim + 1)
+     return reordered_tensor.reshape(old_shape)
+
+
+ def slice_sharded_tensor_for_concatenation(sharded_tensor: jax.Array,
+                                            split_sizes: list[int],
+                                            n_shards: int):
+     """
+     Slice the input tensor, which is sharded on multiple chips (on the last dim), into individual tensors with the same sharding.
+
+     For example, let the sharded_tensor be:
+         AAABBC | AAABBC | AAABBC | AAABBC
+         Shard0   Shard1   Shard2   Shard3
+     and let the split_sizes = [12, 8, 4] and n_shards = 4.
+     The output is a list of 3 tensors:
+         AAA | AAA | AAA | AAA
+         BB  | BB  | BB  | BB
+         C   | C   | C   | C
+         Shard0 Shard1 Shard2 Shard3
+     In other words, each individual tensor is a slice of the input tensor with the same sharding.
+
+     Args:
+         sharded_tensor: the input tensor, sharded on the last dim.
+         split_sizes: each individual tensor's size on the last dim.
+         n_shards: number of shards.
+     """
+     # The new shape ensures each sharded_tensor[..., i, :] maps to a tensor on the ith shard.
+     new_shape = sharded_tensor.shape[:-1] + (n_shards, -1)
+     sharded_tensor = sharded_tensor.reshape(new_shape)
+
+     split_tensors = []
+     start_offset = 0
+     for split_size in split_sizes:
+         assert split_size % n_shards == 0
+         sz = split_size // n_shards  # size of this split tensor per shard
+         end_offset = start_offset + sz
+         # Because we are slicing over the last dim, the sharding dim remains
+         # intact; therefore, the splitting happens locally on each shard.
+         split_tensor = sharded_tensor[..., start_offset:end_offset]
+         split_tensors.append(split_tensor.reshape(new_shape[:-2] + (-1, )))
+         start_offset = end_offset
+
+     return split_tensors
+
+
+ def torch_to_jax_param(
+     tensor: torch.Tensor,
+     sharding: NamedSharding,
+     output_sizes: Optional[list[int]],
+     n_shards: int,
+     fused: bool,
+     dim: int = 0,
+     jax_dtype: Optional[jnp.dtype] = None,
+ ) -> Union[torch.nn.Parameter, torch.nn.ParameterList]:
+     if output_sizes is None:
+         output_sizes = [tensor.shape[0]]
+
+     tensor = t2j(tensor, use_dlpack=False)
+     if jax_dtype:
+         tensor = tensor.astype(jax_dtype)
+
+     if fused:
+         tensor = reorder_concatenated_tensor_for_sharding(
+             tensor, output_sizes, n_shards, dim)
+         tensor = jax.device_put(tensor, sharding)
+         param = torch.nn.Parameter(torch_view(tensor), requires_grad=False)
+     else:
+         tensors = []
+         start_offset = 0
+         for size in output_sizes:
+             end_offset = start_offset + size
+
+             tensor_split = jax.lax.slice_in_dim(tensor,
+                                                 start_offset,
+                                                 end_offset,
+                                                 axis=dim)
+             tensor_split = jax.device_put(tensor_split, sharding)
+             tensor_split = torch.nn.Parameter(torch_view(tensor_split),
+                                               requires_grad=False)
+             tensors.append(tensor_split)
+
+             start_offset = end_offset
+         param = torch.nn.ParameterList(tensors)
+     return param
+
+
+ MODEL_MATMUL_FUSION_TRUTH_TABLE = {
+     ("Qwen/Qwen2.5-7B-Instruct", 1024, 1, "QKVParallelLinear"):
+     True,
+     ("Qwen/Qwen2.5-7B-Instruct", 1024, 1, "MergedColumnParallelLinear"):
+     False,
+     ("Qwen/Qwen2.5-7B-Instruct", 2048, 1, "QKVParallelLinear"):
+     False,
+     ("Qwen/Qwen2.5-7B-Instruct", 2048, 1, "MergedColumnParallelLinear"):
+     False,
+     ("meta-llama/Llama-3.1-8B-Instruct", 1024, 1, "QKVParallelLinear"):
+     False,
+     ("meta-llama/Llama-3.1-8B-Instruct", 1024, 1, "MergedColumnParallelLinear"):
+     False,
+     ("meta-llama/Llama-3.1-8B-Instruct", 2048, 1, "QKVParallelLinear"):
+     False,
+     ("meta-llama/Llama-3.1-8B-Instruct", 2048, 1, "MergedColumnParallelLinear"):
+     False,
+     ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 1024, 1, "QKVParallelLinear"):
+     False,
+     ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 1024, 1, "MergedColumnParallelLinear"):
+     False,
+     ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 2048, 1, "QKVParallelLinear"):
+     False,
+     ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 2048, 1, "MergedColumnParallelLinear"):
+     False,
+ }
+
+
+ def get_model_matmul_fusion_assignment(model_name: str, batch_size: int,
+                                        tp_size: int, layer_name: str):
+     key = (model_name, batch_size, tp_size, layer_name)
+     return MODEL_MATMUL_FUSION_TRUTH_TABLE.get(key, True)
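Note: this hunk appears to be tpu_inference/layers/vllm/linear_common.py (the 221-line file in the listing above). The two reshuffling helpers are easiest to follow with the docstrings' AAABBC example; below is a minimal host-only sketch, assuming only that module path and using integers 0/1/2 in place of the letters A/B/C.

import jax.numpy as jnp

from tpu_inference.layers.vllm.linear_common import (
    reorder_concatenated_tensor_for_sharding,
    slice_sharded_tensor_for_concatenation)

# 12 "A"s (0), 8 "B"s (1) and 4 "C"s (2), concatenated on dim 0.
concatenated = jnp.concatenate([
    jnp.zeros(12, jnp.int32),
    jnp.ones(8, jnp.int32),
    jnp.full(4, 2, jnp.int32),
])

reordered = reorder_concatenated_tensor_for_sharding(
    concatenated, [12, 8, 4], n_shards=4, dim=0)
# Every contiguous quarter of `reordered` is now [0 0 0 1 1 2] (AAABBC), so an
# even split across 4 shards keeps each source tensor's rows together.

parts = slice_sharded_tensor_for_concatenation(reordered, [12, 8, 4], n_shards=4)
# parts[0].shape == (12,), parts[1].shape == (8,), parts[2].shape == (4,),
# each laid out shard-by-shard, matching the docstring's AAA / BB / C picture.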
@@ -0,0 +1,55 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import copy
+
+ from jax.sharding import Mesh
+ from vllm.config import VllmConfig
+ from vllm.model_executor.layers.quantization.base_config import \
+     QuantizationConfig
+
+ from tpu_inference.layers.common import quant_methods
+ from tpu_inference.layers.vllm.quantization.awq import VllmAWQConfig
+ from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+ from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
+     VllmCompressedTensorsConfig  # noqa: E501
+ from tpu_inference.layers.vllm.quantization.fp8 import VllmFp8Config
+ from tpu_inference.layers.vllm.quantization.mxfp4 import VllmMxfp4Config
+ from tpu_inference.layers.vllm.quantization.unquantized import \
+     VllmUnquantizedConfig
+
+
+ def get_tpu_quantization_config(vllm_config: VllmConfig,
+                                 mesh: Mesh) -> QuantizationConfig:
+     model_config = copy.deepcopy(vllm_config.model_config)
+     # TODO(kyuyeunk): Add support for "tpu_int8".
+     method_to_config: dict[str | None, type[QuantizationConfig]] = {
+         None: VllmUnquantizedConfig,
+         quant_methods.COMPRESSED_TENSORS: VllmCompressedTensorsConfig,
+         quant_methods.AWQ: VllmAWQConfig,
+         quant_methods.FP8: VllmFp8Config,
+         quant_methods.MXFP4: VllmMxfp4Config,
+     }
+     if model_config.quantization not in method_to_config:
+         raise NotImplementedError(
+             f"{model_config.quantization} quantization method not supported."
+             f" Supported methods are {method_to_config.keys()}")
+     quant_config = method_to_config[model_config.quantization]
+     assert issubclass(quant_config, JaxCommonConfig)
+     quant_config.set_configs(vllm_config, mesh)
+
+     model_config.quantization = quant_methods.get_tpu_quant_method(
+         quant_config.get_name())
+     return VllmConfig.get_quantization_config(model_config,
+                                               vllm_config.load_config)
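Note: this hunk appears to be tpu_inference/layers/vllm/quantization/__init__.py (the 55-line file in the listing). For context, a hedged sketch of how a caller might invoke this factory; the mesh construction and the "model" axis name below are assumptions, not taken from the package.

import jax
from jax.sharding import Mesh
from vllm.config import VllmConfig

from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config

def load_tpu_quant_config(vllm_config: VllmConfig):
    # Assumed 1-D tensor-parallel mesh over all local TPU devices.
    mesh = Mesh(jax.devices(), ("model",))
    # Anything missing from the method_to_config table above raises
    # NotImplementedError before any TPU-specific config is built.
    return get_tpu_quantization_config(vllm_config, mesh)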
@@ -0,0 +1,221 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional, Union
+
+ import jax
+ import jax.numpy as jnp
+ import torch
+ from jax.sharding import NamedSharding, PartitionSpec
+ from torchax.interop import jax_view, torch_view
+ from vllm.logger import init_logger
+ from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+ from vllm.model_executor.layers.quantization import \
+     register_quantization_config
+ from vllm.model_executor.layers.quantization.awq import (AWQConfig,
+                                                          AWQLinearMethod)
+ from vllm.model_executor.layers.quantization.base_config import \
+     QuantizeMethodBase
+ from vllm.model_executor.layers.quantization.utils.quant_utils import (
+     is_layer_skipped, unpack_quantized_values_into_int32)
+ from vllm.scalar_type import scalar_types
+
+ from tpu_inference.layers.common.quant_methods import AWQ, get_tpu_quant_method
+ from tpu_inference.layers.vllm.linear_common import (
+     slice_sharded_tensor_for_concatenation, torch_to_jax_param)
+ from tpu_inference.layers.vllm.quantization.common import (
+     JaxCommonConfig, JaxCommonLinearConfig)
+ from tpu_inference.layers.vllm.quantization.unquantized import \
+     VllmUnquantizedLinearMethod
+
+ P = PartitionSpec
+ logger = init_logger(__name__)
+
+
+ @register_quantization_config(get_tpu_quant_method(AWQ))
+ class VllmAWQConfig(AWQConfig, JaxCommonConfig):
+
+     @classmethod
+     def get_name(cls):
+         return AWQ
+
+     def get_supported_act_dtypes(self) -> list[torch.dtype]:
+         # NOTE: AWQ checkpoints are quantized with float16, but on TPUs
+         # bfloat16 is significantly preferred over float16. This might lead
+         # to some numeric change in the output.
+         return [torch.bfloat16]
+
+     def get_quant_method(
+         self, layer: torch.nn.Module, prefix: str
+     ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
+         if isinstance(layer, LinearBase):
+             linear_config = self.get_linear_config(layer)
+             if is_layer_skipped(prefix, self.modules_to_not_convert):
+                 return VllmUnquantizedLinearMethod(linear_config)
+             return VllmAWQLinearMethod(self, linear_config)
+         elif isinstance(layer, FusedMoE):
+             raise NotImplementedError(
+                 "AWQ FusedMoE is currently not supported in torchax-jax")
+         return None
+
+
+ class VllmAWQLinearMethod(AWQLinearMethod):
+
+     def __init__(self, quant_config: VllmAWQConfig,
+                  jax_config: JaxCommonLinearConfig):
+         super().__init__(quant_config)
+         self.jax_config = jax_config
+
+         out_sharding, in_sharding = self.jax_config.weight_sharding[:]
+         self.jax_config.weight_sharding = P(in_sharding, None, out_sharding)
+         self.jax_config.scale_sharding = P(in_sharding, out_sharding)
+
+     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+         qweight = layer.qweight
+         qweight = unpack_awq_weight(qweight, qweight.packed_dim)
+
+         group_size = self.quant_config.group_size
+         # Reshape so that each qweight[i] was quantized with the same scales[i].
+         qweight = qweight.reshape((-1, group_size, layer.output_size))
+         qweight = torch_to_jax_param(qweight,
+                                      NamedSharding(
+                                          self.jax_config.mesh,
+                                          self.jax_config.weight_sharding),
+                                      self.jax_config.output_sizes,
+                                      self.jax_config.n_shards,
+                                      self.jax_config.fuse_matmuls,
+                                      dim=2,
+                                      jax_dtype=jnp.uint4)
+         delattr(layer, "qweight")
+         layer.qweight = qweight
+
+         qzeros = layer.qzeros
+         qzeros = unpack_awq_weight(qzeros, qzeros.packed_dim)
+         qzeros = torch_to_jax_param(qzeros,
+                                     NamedSharding(
+                                         self.jax_config.mesh,
+                                         self.jax_config.scale_sharding),
+                                     self.jax_config.output_sizes,
+                                     self.jax_config.n_shards,
+                                     self.jax_config.fuse_matmuls,
+                                     dim=1,
+                                     jax_dtype=jnp.uint4)
+         delattr(layer, "qzeros")
+         layer.qzeros = qzeros
+
+         scales = torch_to_jax_param(layer.scales,
+                                     NamedSharding(
+                                         self.jax_config.mesh,
+                                         self.jax_config.scale_sharding),
+                                     self.jax_config.output_sizes,
+                                     self.jax_config.n_shards,
+                                     self.jax_config.fuse_matmuls,
+                                     dim=1)
+         delattr(layer, "scales")
+         layer.scales = scales
+
+         if layer.bias is not None and not layer.skip_bias_add:
+             if layer.return_bias:
+                 logger.warning_once("Bias might return incorrect value.")
+
+             bias = torch_to_jax_param(
+                 layer.bias,
+                 NamedSharding(self.jax_config.mesh,
+                               self.jax_config.bias_sharding),
+                 self.jax_config.output_sizes,
+                 self.jax_config.n_shards,
+                 self.jax_config.fuse_matmuls,
+             )
+             delattr(layer, "bias")
+             layer.bias = bias
+
+     def apply(self,
+               layer: torch.nn.Module,
+               x: torch.Tensor,
+               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+         with jax.named_scope(layer._get_name()):
+             if self.jax_config.fuse_matmuls:
+                 out = self._apply_fused(layer, x, bias)
+             else:
+                 out = self._apply_split(layer, x, bias)
+
+         return out
+
+     def _apply_fused(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+         x_jax = jax_view(x)
+
+         qweight = jax_view(layer.qweight)
+         qzeros = jnp.expand_dims(jax_view(layer.qzeros), 1)
+         scales = jnp.expand_dims(jax_view(layer.scales), 1)
+
+         qweight = qweight.astype(jnp.int8)
+         qzeros = qzeros.astype(jnp.int8)
+
+         weight = (qweight - qzeros) * scales
+         weight = weight.reshape((-1, weight.shape[-1]))
+         outs = jnp.einsum("bd,df->bf", x_jax, weight)
+
+         if bias is not None and not layer.skip_bias_add:
+             outs += bias.jax()
+
+         outs = slice_sharded_tensor_for_concatenation(
+             outs, self.jax_config.output_sizes, self.jax_config.n_shards)
+         out = jnp.concatenate(outs, axis=-1)
+         return torch_view(out)
+
+     def _apply_split(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+         assert isinstance(layer.qweight, torch.nn.ParameterList)
+
+         x_jax = jax_view(x)
+         params = zip(layer.qweight, layer.qzeros, layer.scales)
+         outs = []
+         for i, (qweight, qzeros, scales) in enumerate(params):
+             qweight = jax_view(qweight)
+             scales = jnp.expand_dims(jax_view(scales), 1)
+             qzeros = jnp.expand_dims(jax_view(qzeros), 1)
+
+             qweight = qweight.astype(jnp.int8)
+             qzeros = qzeros.astype(jnp.int8)
+
+             weight = (qweight - qzeros) * scales
+             weight = weight.reshape((-1, weight.shape[-1]))
+             out = jnp.einsum("bd,df->bf", x_jax, weight)
+
+             if bias is not None and not layer.skip_bias_add:
+                 out += jax_view(bias[i])
+
+             outs.append(out)
+         out = jnp.concatenate(outs, axis=-1)
+         return torch_view(out)
+
+
+ def unpack_awq_weight(weight: torch.Tensor, packed_dim: int):
+     weight = unpack_quantized_values_into_int32(weight, scalar_types.uint4,
+                                                 packed_dim)
+
+     # AWQ packs 8 uint4 values into 32 bits in the order (0, 2, 4, 6, 1, 3, 5, 7).
+     # The following index order maps the AWQ packing order back to ascending order.
+     reverse_awq_order = (0, 4, 1, 5, 2, 6, 3, 7)
+
+     orig_shape = weight.shape
+     weight = weight.reshape(orig_shape[:-1] + (-1, 8))
+     return weight[..., reverse_awq_order].reshape(orig_shape)
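Note: this hunk appears to be tpu_inference/layers/vllm/quantization/awq.py. The dequantization in _apply_fused and _apply_split is simply `(qweight - qzeros) * scales` broadcast over each quantization group. A small shape-only sketch follows; the group size and layer dimensions are made-up assumptions chosen to mirror the reshapes above, not values from the package.

import jax.numpy as jnp

in_features, out_features, group_size = 4096, 256, 128
n_groups = in_features // group_size

# After unpack_awq_weight plus the reshape above, qweight is grouped as
# [groups, group_size, out]; qzeros/scales hold one entry per (group, output).
qweight = jnp.zeros((n_groups, group_size, out_features), jnp.int8)
qzeros = jnp.zeros((n_groups, 1, out_features), jnp.int8)
scales = jnp.ones((n_groups, 1, out_features), jnp.bfloat16)

weight = (qweight - qzeros) * scales             # broadcast over the group dim
weight = weight.reshape((-1, weight.shape[-1]))  # [in_features, out_features]

x = jnp.ones((8, in_features), jnp.bfloat16)
out = jnp.einsum("bd,df->bf", x, weight)         # [8, out_features]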
@@ -0,0 +1,124 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torchax
+ from jax.sharding import Mesh, PartitionSpec
+ from vllm.config import VllmConfig
+ from vllm.logger import init_logger
+ from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEConfig
+ # yapf: disable
+ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                                LinearBase,
+                                                MergedColumnParallelLinear,
+                                                QKVParallelLinear,
+                                                ReplicatedLinear,
+                                                RowParallelLinear)
+
+ from tpu_inference.layers.vllm.linear_common import \
+     get_model_matmul_fusion_assignment
+ from tpu_inference.utils import TPU_SECOND_LAST_MINOR
+
+ # yapf: enable
+
+ P = PartitionSpec
+
+ logger = init_logger(__name__)
+
+
+ class JaxCommonLinearConfig:
+
+     def __init__(self, vllm_config: VllmConfig, mesh: Mesh, layer: LinearBase):
+         assert isinstance(layer, LinearBase)
+
+         self.mesh = mesh
+         self.output_sizes = [layer.output_size]
+         self.weight_sharding = P(None, None)
+         self.fuse_matmuls = True
+         self.enable_sp = vllm_config.compilation_config.pass_config.enable_sp
+         self.input_sharding = None
+         self.output_sharding = None
+
+         if isinstance(layer, RowParallelLinear):
+             self.weight_sharding = P(None, "model")
+             if self.enable_sp:
+                 self.output_sharding = P("model", None)
+         elif isinstance(layer, ColumnParallelLinear):
+             self.weight_sharding = P("model", None)
+             if self.enable_sp:
+                 self.input_sharding = P("model", None)
+
+             if isinstance(layer, MergedColumnParallelLinear) or isinstance(
+                     layer, QKVParallelLinear):
+                 self.output_sizes = layer.output_sizes
+
+             self.fuse_matmuls = get_model_matmul_fusion_assignment(
+                 vllm_config.model_config.model,
+                 vllm_config.scheduler_config.max_num_batched_tokens,
+                 vllm_config.parallel_config.tensor_parallel_size,
+                 layer._get_name())
+         elif isinstance(layer, ReplicatedLinear):
+             self.weight_sharding = P(None, None)
+         else:
+             logger.warning(
+                 "Unsupported linear layer type of %s. Can potentially yield "
+                 "bad performance.", type(layer))
+
+         self.bias_sharding = P(self.weight_sharding[0])
+         if isinstance(self.weight_sharding[0], tuple):
+             self.n_shards = 1
+             for axis in self.weight_sharding[0]:
+                 self.n_shards *= self.mesh.shape.get(axis, 1)
+         else:
+             self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
+
+     def get_input_sharding(self, x: torchax.tensor.Tensor):
+         if self.enable_sp:
+             token_num = x.shape[0]
+             # NOTE(chengjiyao): make sure the sharded token_num is at least
+             # TPU_SECOND_LAST_MINOR.
+             if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
+                 return self.input_sharding
+             else:
+                 return None
+         return self.input_sharding
+
+     def get_output_sharding(self, x: torchax.tensor.Tensor):
+         if self.enable_sp:
+             token_num = x.shape[0]
+             # NOTE(chengjiyao): make sure the sharded token_num is at least
+             # TPU_SECOND_LAST_MINOR.
+             if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
+                 return self.output_sharding
+             else:
+                 return None
+         return self.output_sharding
+
+
+ class JaxCommonConfig:
+     vllm_config: VllmConfig
+     mesh: Mesh
+
+     @classmethod
+     def set_configs(cls, vllm_config: VllmConfig, mesh: Mesh):
+         cls.vllm_config = vllm_config
+         cls.mesh = mesh
+
+     def get_linear_config(self, layer: LinearBase) -> JaxCommonLinearConfig:
+         assert isinstance(layer, LinearBase)
+         return JaxCommonLinearConfig(self.vllm_config, self.mesh, layer)
+
+     def get_moe_config(self, layer: FusedMoE) -> FusedMoEConfig:
+         assert isinstance(layer, FusedMoE)
+         moe_config = layer.moe_config
+         use_ep = self.vllm_config.parallel_config.enable_expert_parallel
+         moe_config.moe_parallel_config.use_ep = use_ep
+         return moe_config
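Note: this hunk appears to be tpu_inference/layers/vllm/quantization/common.py. The n_shards bookkeeping at the end of JaxCommonLinearConfig.__init__ multiplies the mesh sizes of whatever axes appear in the first entry of the weight PartitionSpec. A standalone illustration of that rule, assuming a made-up 2x4 mesh represented as a plain dict of axis sizes:

from jax.sharding import PartitionSpec as P

mesh_shape = {"data": 2, "model": 4}  # assumed axis sizes; real code reads mesh.shape

def count_weight_shards(weight_sharding: P) -> int:
    axis = weight_sharding[0]
    if isinstance(axis, tuple):
        n = 1
        for name in axis:
            n *= mesh_shape.get(name, 1)
        return n
    return mesh_shape.get(axis, 1)

assert count_weight_shards(P("model", None)) == 4           # ColumnParallelLinear
assert count_weight_shards(P(None, "model")) == 1           # RowParallelLinear
assert count_weight_shards(P(("data", "model"), None)) == 8  # fused axes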
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.