tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (250)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +45 -15
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +41 -16
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  240. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Utility functions for ragged paged attention."""
 import jax
 from jax._src import dtypes
@@ -13,7 +26,8 @@ def align_to(x, a):
 
 
 def get_dtype_bitwidth(dtype):
-    return dtypes.bit_width(dtype)
+    return (dtypes.bit_width(dtype)
+            if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
 
 
 def get_dtype_packing(dtype):
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -1,10 +1,23 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import functools
 import math
 from typing import Any, Callable, Optional, Tuple
 
 import jax
 import jax.numpy as jnp
-from jax.experimental import shard_map
 from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention
 from jax.experimental.pallas.ops.tpu.splash_attention import \
     splash_attention_kernel as splash
@@ -55,11 +68,11 @@ def sharded_flash_attention(
         vmem_limit_bytes=vmem_limit_bytes)
 
     return jax.jit(
-        shard_map.shard_map(_flash_attention,
-                            mesh=mesh,
-                            in_specs=in_specs,
-                            out_specs=out_specs,
-                            check_rep=False))
+        jax.shard_map(_flash_attention,
+                      mesh=mesh,
+                      in_specs=in_specs,
+                      out_specs=out_specs,
+                      check_vma=False))
 
 
 def sharded_paged_attention(
@@ -94,12 +107,12 @@ def sharded_paged_attention(
     )
 
     return jax.jit(
-        shard_map.shard_map(
+        jax.shard_map(
             _paged_attention_fn,
             mesh=mesh,
             in_specs=in_specs,
             out_specs=out_specs,
-            check_rep=False,
+            check_vma=False,
         ))
 
 
@@ -257,7 +270,7 @@ def sharded_splash_attention(
     )
     out_specs = P("data", "model", None, None)
     return jax.jit(
-        shard_map.shard_map(
+        jax.shard_map(
             functools.partial(
                 apply_splash,
                 window_size=window_size,
@@ -267,7 +280,7 @@
             mesh=mesh,
             in_specs=in_specs,
             out_specs=out_specs,
-            check_rep=False,
+            check_vma=False,
         ))
 
 
@@ -308,13 +321,7 @@ def sharded_ragged_paged_attention(
     args = (q, k, v, kv_cache, kv_lens, page_indices, cu_q_lens, distribution)
 
     use_hd64 = q.shape[-1] == 64
-
-    func = ragged_paged_attention
-    if use_hd64:
-        func = functools.partial(ragged_paged_attention_hd64,
-                                 strict_sliding_window=False)
-    else:
-        func = ragged_paged_attention
+    func = ragged_paged_attention_hd64 if use_hd64 else ragged_paged_attention
 
     if attention_sink is not None:
         if not use_hd64:
@@ -334,12 +341,12 @@ def sharded_ragged_paged_attention(
         v_scale=v_scale,
     )
 
-    return shard_map.shard_map(
+    return jax.shard_map(
         _ragged_paged_attention,
         mesh=mesh,
         in_specs=in_specs,
         out_specs=out_specs,
-        check_rep=False,
+        check_vma=False,
     )(*args)

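Most of the mechanical churn in the hunks above is the migration from jax.experimental.shard_map.shard_map(..., check_rep=False) to the public jax.shard_map(..., check_vma=False) API. A minimal sketch of the new call shape, assuming a JAX release new enough to expose jax.shard_map and the check_vma flag; the mesh axis name and the doubling function are illustrative only, not taken from this diff:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

# 1D mesh over all local devices; the "model" axis name is illustrative.
mesh = Mesh(jax.devices(), axis_names=("model", ))

def _double(x):
    # Runs per shard on the local block of the sharded input.
    return x * 2

sharded_double = jax.jit(
    jax.shard_map(_double,
                  mesh=mesh,
                  in_specs=P("model"),
                  out_specs=P("model"),
                  check_vma=False))  # replaces the old check_rep=False

x = jnp.arange(4.0 * jax.device_count())  # length divisible by the mesh size
print(sharded_double(x))

The call structure (jit wrapping a shard_map'd kernel with explicit in_specs/out_specs) mirrors what the attention wrappers above do with the Pallas kernels.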
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import functools
 from dataclasses import dataclass, field
 from typing import Any
@@ -1,7 +1,22 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 UNQUANTIZED = "unquantized"
 MXFP4 = "mxfp4"
 AWQ = "awq"
 COMPRESSED_TENSORS = "compressed-tensors"
+FP8 = "fp8"
 
 
 def get_tpu_quant_method(quant_method: str) -> str:
@@ -0,0 +1,270 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+from typing import Tuple
+
+import jax
+import jax.numpy as jnp
+
+MXFP4_BLOCK_SIZE = 32
+
+
+def quantize_tensor_to_mxfp4_packed(
+    tensor: jax.Array,
+    axis: int | tuple = -1,
+) -> Tuple[jax.Array, jax.Array]:
+    """Quantize a tensor to mxfp4 and pack it into uint8."""
+
+    # Perform regular block quantization.
+    tensor_q, scale = quantize_tensor(
+        jnp.float4_e2m1fn,
+        tensor,
+        axis,
+        MXFP4_BLOCK_SIZE,
+    )
+
+    # last two e2m1 elements will be packed into a single uint8 element.
+    bitcast_shape = tensor_q.shape[:-1] + (-1, 2)
+    tensor_q = tensor_q.reshape(bitcast_shape)
+    tensor_q_packed = jax.lax.bitcast_convert_type(tensor_q, jnp.uint8)
+
+    # Since TPU does not have native support for e8m0, we convert scale into
+    # e8m0 manually and store it as uint8.
+    e8m0_finfo = jnp.finfo(jnp.float8_e8m0fnu)
+    _, scale_exp = jnp.frexp(scale)
+    # Subtract exponents by one since e8m0 has no decimal.
+    scale_exp -= 1
+    scale_exp = (scale_exp - e8m0_finfo.minexp).astype(jnp.uint8)
+
+    return tensor_q_packed, scale_exp
+
+
+def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:
+    """Unpack e2m1 tensor packed into u8."""
+    assert u8_packed_e2m1.dtype == jnp.uint8
+    e2m1 = jax.lax.bitcast_convert_type(u8_packed_e2m1, jnp.float4_e2m1fn)
+    # bitcast creates one more dimension that splits 8 bits into two e2m1.
+    # we flatten them with the last dim.
+    return jnp.reshape(e2m1, e2m1.shape[:-2] + (-1, ))
+
+
+def e8m0_to_fp32(u8: jax.Array) -> jax.Array:
+    """Convert e8m0 (that was bitcasted to u8) into fp32"""
+    assert u8.dtype == jnp.uint8
+
+    e8_finfo = jnp.finfo(jnp.float8_e8m0fnu)
+    exponents = u8.astype(jnp.int32) + e8_finfo.minexp
+    ones = jnp.ones_like(u8, dtype=jnp.float32)
+    return jnp.ldexp(ones, exponents)
+
+
+def dequantize_tensor(
+    tensor_q: jax.Array,
+    scale: jax.Array,
+    axis: int | None | tuple = -1,
+    out_dtype: jnp.dtype = jnp.bfloat16,
+) -> jax.Array:
+    """Dequantize a quantized tensor.
+
+    Args:
+        tensor_q: Quantized tensor.
+        scale: Quantization scale.
+        axis: The axis tensor was quantized. None denotes per-tensor.
+        out_dtype: Dtype of the output.
+
+    Returns:
+        Dequantized tensor_q.
+    """
+    if axis is None:
+        # Perform per-tensor quantization.
+        axis = [i for i in range(tensor_q.ndim)]
+    if isinstance(axis, int):
+        axis = [axis]
+
+    orig_shape = tensor_q.shape
+    if tensor_q.ndim == scale.ndim:
+        # Indicates the tensor was block quantized.
+        blocked_shape = [[i] for i in orig_shape]
+        for i in axis:
+            num_blocks = scale.shape[i]
+            if tensor_q.shape[i] % num_blocks:
+                raise ValueError(
+                    f"Unable to perform block dequantization. axis={i} of "
+                    f"{tensor_q.shape=} is not divisible by {num_blocks=}", )
+            block_size = tensor_q.shape[i] // num_blocks
+
+            blocked_shape[i] = (num_blocks, block_size)
+
+        # Convert all axis into positive values.
+        axis = sorted([(i + tensor_q.ndim) % tensor_q.ndim for i in axis])
+        # Shift axis by 1 since its original position is now occupied by
+        # num_blocks dim. Also, if n axes before an axis was also quantized,
+        # shift its position by n.
+        axis = [1 + n + i for n, i in enumerate(axis)]
+
+        # Flatten list of lists that contains (num_blocks, block).
+        blocked_shape = list(itertools.chain(*blocked_shape))
+        tensor_q = tensor_q.reshape(blocked_shape)
+
+    scale = jnp.expand_dims(scale, axis)
+
+    tensor = (tensor_q.astype(jnp.float32) * scale).astype(out_dtype)
+
+    return tensor.reshape(orig_shape)
+
+
+def dequantize_tensor_from_mxfp4_packed(
+    tensor_q: jax.Array,
+    scale: jax.Array,
+    axis: int | tuple = -1,
+    out_dtype: jnp.dtype = jnp.bfloat16,
+) -> jax.Array:
+    """Dequantize packed mxfp4 tensor.
+
+    Args:
+        tensor_q: fp4 tensor packed into uint8.
+        scale: e8m0 scale packed into uint8.
+        axis: The axis tensor was quantized.
+        out_dtype: Dtype of the output.
+
+    Returns:
+        Dequantized tensor_q.
+    """
+    tensor_e2m1 = u8_unpack_e2m1(tensor_q)
+    scale_fp32 = e8m0_to_fp32(scale)
+
+    return dequantize_tensor(
+        tensor_e2m1,
+        scale_fp32,
+        axis,
+        out_dtype,
+    )
+
+
+def quantize_tensor(
+    dtype: jnp.dtype,
+    tensor: jax.Array,
+    axis: int | tuple | None = -1,
+    block_size: int | None = None,
+    pad_tensor: bool = False,
+) -> tuple[jax.Array, jax.Array]:
+    """Quantize tensor.
+
+    Args:
+        dtype: dtype to perform quantization.
+        tensor: Unquantized tensor.
+        axis: Axis to perform quantization. None denotes per-tensor.
+        block_size: Specify block quantization size.
+        pad_tensor: Whether to pad the axis along block size.
+
+    Returns:
+        Tensor quantized to dtype.
+    """
+    if axis is None:
+        # Perform per-tensor quantization.
+        axis = [i for i in range(tensor.ndim)]
+    if isinstance(axis, int):
+        axis = [axis]
+
+    orig_shape = tensor.shape
+    mask = jnp.ones_like(tensor, jnp.int32)
+
+    if block_size is not None:
+        if isinstance(block_size, int):
+            block_size = [block_size] * len(axis)
+
+        blocked_shape = [[i] for i in orig_shape]
+        pad_width = [[0, 0] for _ in range(tensor.ndim)]
+        for i, block in zip(axis, block_size):
+            num_blocks = (tensor.shape[i] + block - 1) // block
+            padding_size = num_blocks * block - tensor.shape[i]
+            if padding_size and not pad_tensor:
+                raise ValueError(
+                    f"Unable to perform block quantization. axis={i} of "
+                    f"{tensor.shape=} is not divisible by {block=}")
+
+            # Pad the tensor to align with block size.
+            pad_width[i][1] = padding_size
+
+            blocked_shape[i] = (num_blocks, block)
+
+        # In order to avoid padded values affecting scale value, we pad it
+        # using edge value of the tensor.
+        tensor = jnp.pad(tensor, pad_width, "edge")
+        mask = jnp.pad(mask, pad_width)
+
+        orig_shape = tensor.shape
+        # Convert all axis into positive values.
+        axis = sorted([i % tensor.ndim for i in axis])
+        # Shift axis by 1 since its original position is now occupied by
+        # num_blocks dim. Also, if n axes before an axis was also quantized,
+        # shift its position by n.
+        axis = [1 + n + i for n, i in enumerate(axis)]
+
+        # Flatten list of lists that contains (num_blocks, block).
+        blocked_shape = list(itertools.chain(*blocked_shape))
+        tensor = tensor.reshape(blocked_shape)
+
+    if jnp.issubdtype(dtype, jnp.integer):
+        dtype_info = jnp.iinfo(dtype)
+    else:
+        dtype_info = jnp.finfo(dtype)
+
+    dtype_max = float(dtype_info.max)
+    dtype_min = float(dtype_info.min)
+
+    abs_max = jnp.max(jnp.abs(tensor), axis=axis, keepdims=True)
+    scale = abs_max / dtype_max
+
+    tensor_q = jnp.clip(tensor / scale, dtype_min, dtype_max)
+    tensor_q = tensor_q.reshape(orig_shape)
+    tensor_q = tensor_q.astype(dtype)
+
+    # To avoid padded values affecting output of quantized matmul, we mask them
+    # out with 0s.
+    tensor_q = jnp.where(mask, tensor_q, 0)
+
+    scale = jnp.squeeze(scale, axis).astype(jnp.float32)
+
+    return tensor_q, scale
+
+
+def static_per_tensor_quantize_tensor(
+    dtype: jnp.dtype,
+    tensor: jax.Array,
+    scale: float,
+) -> jax.Array:
+    if jnp.issubdtype(dtype, jnp.integer):
+        dtype_info = jnp.iinfo(dtype)
+    else:
+        dtype_info = jnp.finfo(dtype)
+
+    dtype_max = float(dtype_info.max)
+    dtype_min = float(dtype_info.min)
+
+    return jnp.clip(tensor / scale, dtype_min, dtype_max).astype(dtype)
+
+
+def quantize_kv(
+    dtype: jnp.dtype,
+    key: jax.Array,
+    value: jax.Array,
+    k_scale: float,
+    v_scale: float,
+) -> Tuple[jax.Array, jax.Array]:
+    """Static quantize key and value tensors."""
+    key = static_per_tensor_quantize_tensor(dtype, key, k_scale)
+    value = static_per_tensor_quantize_tensor(dtype, value, v_scale)
+    return key, value
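The new tpu_inference/layers/common/quantization.py module above centralizes the per-channel, block (mxfp4) and static per-tensor quantization helpers used by the JAX layers and the vLLM quantization schemes. A minimal usage sketch of how these helpers compose; the array shapes, dtypes and unit scales are illustrative, not taken from the diff:

import jax.numpy as jnp

from tpu_inference.layers.common.quantization import (
    dequantize_tensor, dequantize_tensor_from_mxfp4_packed, quantize_kv,
    quantize_tensor, quantize_tensor_to_mxfp4_packed)

w = jnp.ones((128, 256), dtype=jnp.bfloat16)

# Per-channel int8 quantization along the last axis, then dequantize back.
w_q, w_scale = quantize_tensor(jnp.int8, w, axis=-1)
w_deq = dequantize_tensor(w_q, w_scale, axis=-1, out_dtype=jnp.bfloat16)

# MXFP4: 32-element blocks, two e2m1 values packed per uint8, e8m0 scales as uint8.
w_fp4_packed, w_e8m0 = quantize_tensor_to_mxfp4_packed(w, axis=-1)
w_fp4_deq = dequantize_tensor_from_mxfp4_packed(w_fp4_packed, w_e8m0, axis=-1)

# Static per-tensor KV-cache quantization with precomputed scales.
k = v = jnp.ones((16, 8, 128), dtype=jnp.bfloat16)
k_q, v_q = quantize_kv(jnp.float8_e4m3fn, k, v, k_scale=1.0, v_scale=1.0)

The quantize_kv helper is the same one the JAX attention layer switches to later in this diff, replacing the old utils.quantize_kv call.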
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import json
 import math
 from dataclasses import asdict, dataclass
@@ -26,7 +40,7 @@ class ShardingAxisNameBase:
     MLP_TENSOR = ('attn_dp', 'model', 'expert')
     MOE_TENSOR = ('attn_dp', 'model')
     EXPERT = ('attn_dp', 'expert', 'model')
-    VOCAB = ('expert', 'model')
+    VOCAB = ('expert', 'attn_dp', 'model')
 
 
 class ShardingAxisName2D:
@@ -119,10 +133,19 @@ class ShardingConfigManager:
            False)
        if enable_dp_attention:
            # Replicate attention layer when num_kv_heads < TP
-            num_kv_heads = vllm_config.model_config.get_total_num_kv_heads()
+            num_kv_heads = 1 if vllm_config.model_config.use_mla else vllm_config.model_config.get_total_num_kv_heads(
+            )
+            cache_dtype = vllm_config.cache_config.cache_dtype
+            if cache_dtype == 'auto':
+                cache_dtype = vllm_config.model_config.dtype
             kv_dtype = utils.get_jax_dtype_from_str_dtype(
-                vllm_config.cache_config.cache_dtype) or jnp.bfloat16
+                cache_dtype) or jnp.bfloat16
             packing = 4 // jnp.dtype(kv_dtype).itemsize
+
+            # The default head dim is 128 but 64 is also supported as a special case.
+            if vllm_config.model_config.get_head_size() == 64:
+                packing *= 2
+
             # When num_kv_heads * 2 / packing < TP, tensor parallelism would
             # duplicate KV heads across devices, wasting kv cache memory.
             # Use attention DP instead to reduce per-device num_kv_heads and
@@ -168,8 +191,8 @@ class ShardingConfigManager:
         if sharding_strategy.attention_data_parallelism > 1:
             if not envs.NEW_MODEL_DESIGN:
                 raise ValueError(
-                    "Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set the "
-                    "NEW_MODEL_DESIGN=True.")
+                    "Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set "
+                    "NEW_MODEL_DESIGN=True")
 
     @property
     def total_dp_size(self) -> int:
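The DP-attention heuristic above now accounts for KV-cache packing and the head-dim-64 special case. A small worked example of the threshold arithmetic, using hypothetical configuration values (model shape and TP size are not from the diff):

# Hypothetical setup: 8 total KV heads, fp8 KV cache, head_dim 128, TP = 8.
num_kv_heads = 8
itemsize = 1                 # bytes per fp8 element
packing = 4 // itemsize      # four fp8 values share a 32-bit word
head_size = 128
if head_size == 64:          # the special-cased smaller head dim doubles packing
    packing *= 2

tensor_parallelism = 8
# Per the comment above: when num_kv_heads * 2 / packing < TP, plain tensor
# parallelism would duplicate KV heads, so attention data parallelism is preferred.
prefer_attention_dp = (num_kv_heads * 2 / packing) < tensor_parallelism
print(prefer_attention_dp)   # 8 * 2 / 4 = 4 < 8 -> True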
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from dataclasses import InitVar, dataclass
 from typing import Any, Tuple
 
@@ -5,7 +19,6 @@ import jax
 import jax.numpy as jnp
 from flax import nnx
 from flax.typing import Sharding
-from jax.experimental import shard_map
 from jax.sharding import Mesh
 from jax.sharding import PartitionSpec as P
 
@@ -13,6 +26,7 @@ from tpu_inference import utils
 from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
     ragged_paged_attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.base import create_param
 from tpu_inference.layers.jax.rope_interface import apply_rope
@@ -149,9 +163,8 @@ class Attention(nnx.Module):
            # q_scale = self._q_scale
            k_scale = self._k_scale
            v_scale = self._v_scale
-            k_SKH, v_SKH = utils.quantize_kv(k_SKH, v_SKH,
-                                             self.kv_cache_quantized_dtype,
-                                             k_scale, v_scale)
+            k_SKH, v_SKH = quantize_kv(self.kv_cache_quantized_dtype, k_SKH,
+                                       v_SKH, k_scale, v_scale)
 
         with jax.named_scope("attn_op"):
             new_kv_cache, outputs_TNH = self.attention(
@@ -236,12 +249,12 @@ class Attention(nnx.Module):
         )
 
         output_TNH, kv_cache = jax.jit(
-            shard_map.shard_map(
+            jax.shard_map(
                 _ragged_paged_attention,
                 mesh=mesh,
                 in_specs=in_specs,
                 out_specs=out_specs,
-                check_rep=False,
+                check_vma=False,
             ))(
                 q_TNH,
                 k_SKH,