tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of tpu-inference might be problematic.
Files changed (248)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +14 -0
  31. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  32. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
  35. tests/layers/__init__.py +13 -0
  36. tests/layers/common/__init__.py +13 -0
  37. tests/layers/common/test_attention_interface.py +156 -0
  38. tests/layers/common/test_quantization.py +149 -0
  39. tests/layers/jax/__init__.py +13 -0
  40. tests/layers/jax/attention/__init__.py +13 -0
  41. tests/layers/jax/attention/test_common_attention.py +103 -0
  42. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  43. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  44. tests/layers/jax/moe/__init__.py +13 -0
  45. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  46. tests/layers/jax/sample/__init__.py +13 -0
  47. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  48. tests/layers/jax/sample/test_sampling.py +115 -0
  49. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  50. tests/layers/jax/test_layers.py +155 -0
  51. tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
  52. tests/layers/jax/test_rope.py +93 -0
  53. tests/layers/jax/test_sharding.py +159 -0
  54. tests/layers/jax/test_transformer_block.py +152 -0
  55. tests/layers/vllm/__init__.py +13 -0
  56. tests/layers/vllm/test_attention.py +363 -0
  57. tests/layers/vllm/test_awq.py +406 -0
  58. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  59. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  61. tests/layers/vllm/test_fp8.py +17 -0
  62. tests/layers/vllm/test_mxfp4.py +320 -0
  63. tests/layers/vllm/test_unquantized.py +662 -0
  64. tests/layers/vllm/utils.py +87 -0
  65. tests/lora/__init__.py +13 -0
  66. tests/lora/conftest.py +14 -0
  67. tests/lora/test_bgmv.py +14 -0
  68. tests/lora/test_layers.py +25 -8
  69. tests/lora/test_lora.py +15 -1
  70. tests/lora/test_lora_perf.py +14 -0
  71. tests/models/__init__.py +13 -0
  72. tests/models/common/__init__.py +13 -0
  73. tests/models/common/test_model_loader.py +455 -0
  74. tests/models/jax/__init__.py +13 -0
  75. tests/models/jax/test_deepseek_v3.py +401 -0
  76. tests/models/jax/test_llama3.py +184 -0
  77. tests/models/jax/test_llama4.py +298 -0
  78. tests/models/jax/test_llama_eagle3.py +197 -0
  79. tests/models/jax/test_llama_guard_4.py +242 -0
  80. tests/models/jax/test_qwen2.py +172 -0
  81. tests/models/jax/test_qwen2_5_vl.py +605 -0
  82. tests/models/jax/test_qwen3.py +169 -0
  83. tests/models/jax/test_weight_loading.py +180 -0
  84. tests/models/jax/utils/__init__.py +13 -0
  85. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  86. tests/platforms/__init__.py +13 -0
  87. tests/platforms/test_tpu_platform.py +54 -0
  88. tests/runner/__init__.py +13 -0
  89. tests/runner/test_block_table.py +395 -0
  90. tests/runner/test_input_batch.py +226 -0
  91. tests/runner/test_kv_cache.py +220 -0
  92. tests/runner/test_kv_cache_manager.py +498 -0
  93. tests/runner/test_multimodal_manager.py +429 -0
  94. tests/runner/test_persistent_batch_manager.py +84 -0
  95. tests/runner/test_speculative_decoding_manager.py +368 -0
  96. tests/runner/test_structured_decoding_manager.py +220 -0
  97. tests/runner/test_tpu_runner.py +261 -0
  98. tests/runner/test_tpu_runner_dp.py +1099 -0
  99. tests/runner/test_tpu_runner_mesh.py +200 -0
  100. tests/runner/test_utils.py +411 -0
  101. tests/spec_decode/__init__.py +13 -0
  102. tests/spec_decode/test_eagle3.py +311 -0
  103. tests/test_base.py +14 -0
  104. tests/test_tpu_info.py +14 -0
  105. tests/test_utils.py +1 -43
  106. tests/worker/__init__.py +13 -0
  107. tests/worker/tpu_worker_test.py +414 -0
  108. tpu_inference/__init__.py +14 -0
  109. tpu_inference/core/__init__.py +13 -0
  110. tpu_inference/core/sched/__init__.py +13 -0
  111. tpu_inference/core/sched/dp_scheduler.py +372 -56
  112. tpu_inference/distributed/__init__.py +13 -0
  113. tpu_inference/distributed/jax_parallel_state.py +14 -0
  114. tpu_inference/distributed/tpu_connector.py +14 -9
  115. tpu_inference/distributed/utils.py +56 -4
  116. tpu_inference/executors/__init__.py +13 -0
  117. tpu_inference/executors/ray_distributed_executor.py +20 -3
  118. tpu_inference/experimental/__init__.py +13 -0
  119. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  120. tpu_inference/kernels/__init__.py +13 -0
  121. tpu_inference/kernels/collectives/__init__.py +13 -0
  122. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  123. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  124. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  125. tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
  126. tpu_inference/kernels/megablox/__init__.py +13 -0
  127. tpu_inference/kernels/megablox/common.py +54 -0
  128. tpu_inference/kernels/megablox/gmm.py +646 -0
  129. tpu_inference/kernels/mla/__init__.py +13 -0
  130. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  131. tpu_inference/kernels/mla/v1/kernel.py +20 -26
  132. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  133. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  134. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  135. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  136. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
  137. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
  138. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  139. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
  140. tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
  141. tpu_inference/layers/__init__.py +13 -0
  142. tpu_inference/layers/common/__init__.py +13 -0
  143. tpu_inference/layers/common/attention_interface.py +26 -19
  144. tpu_inference/layers/common/attention_metadata.py +14 -0
  145. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  146. tpu_inference/layers/common/quant_methods.py +15 -0
  147. tpu_inference/layers/common/quantization.py +282 -0
  148. tpu_inference/layers/common/sharding.py +22 -3
  149. tpu_inference/layers/common/utils.py +94 -0
  150. tpu_inference/layers/jax/__init__.py +13 -0
  151. tpu_inference/layers/jax/attention/__init__.py +13 -0
  152. tpu_inference/layers/jax/attention/attention.py +19 -6
  153. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
  154. tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
  155. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  156. tpu_inference/layers/jax/base.py +14 -0
  157. tpu_inference/layers/jax/constants.py +13 -0
  158. tpu_inference/layers/jax/layers.py +14 -0
  159. tpu_inference/layers/jax/misc.py +14 -0
  160. tpu_inference/layers/jax/moe/__init__.py +13 -0
  161. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  162. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  163. tpu_inference/layers/jax/moe/moe.py +43 -3
  164. tpu_inference/layers/jax/pp_utils.py +53 -0
  165. tpu_inference/layers/jax/rope.py +14 -0
  166. tpu_inference/layers/jax/rope_interface.py +14 -0
  167. tpu_inference/layers/jax/sample/__init__.py +13 -0
  168. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  169. tpu_inference/layers/jax/sample/sampling.py +15 -1
  170. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  171. tpu_inference/layers/jax/transformer_block.py +14 -0
  172. tpu_inference/layers/vllm/__init__.py +13 -0
  173. tpu_inference/layers/vllm/attention.py +4 -4
  174. tpu_inference/layers/vllm/fused_moe.py +100 -455
  175. tpu_inference/layers/vllm/linear.py +64 -0
  176. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  177. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  178. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  179. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  180. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  181. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  182. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  183. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
  184. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  188. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
  189. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  190. tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
  191. tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
  192. tpu_inference/lora/__init__.py +13 -0
  193. tpu_inference/lora/torch_lora_ops.py +8 -13
  194. tpu_inference/models/__init__.py +13 -0
  195. tpu_inference/models/common/__init__.py +13 -0
  196. tpu_inference/models/common/model_loader.py +37 -16
  197. tpu_inference/models/jax/__init__.py +13 -0
  198. tpu_inference/models/jax/deepseek_v3.py +113 -124
  199. tpu_inference/models/jax/gpt_oss.py +23 -7
  200. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  201. tpu_inference/models/jax/llama3.py +99 -36
  202. tpu_inference/models/jax/llama4.py +14 -0
  203. tpu_inference/models/jax/llama_eagle3.py +14 -0
  204. tpu_inference/models/jax/llama_guard_4.py +15 -1
  205. tpu_inference/models/jax/qwen2.py +17 -2
  206. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  207. tpu_inference/models/jax/qwen3.py +17 -2
  208. tpu_inference/models/jax/utils/__init__.py +13 -0
  209. tpu_inference/models/jax/utils/file_utils.py +14 -0
  210. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  211. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
  213. tpu_inference/models/jax/utils/weight_utils.py +32 -1
  214. tpu_inference/models/vllm/__init__.py +13 -0
  215. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
  216. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  217. tpu_inference/platforms/__init__.py +14 -0
  218. tpu_inference/platforms/tpu_platform.py +27 -29
  219. tpu_inference/runner/__init__.py +13 -0
  220. tpu_inference/runner/compilation_manager.py +69 -35
  221. tpu_inference/runner/kv_cache.py +14 -0
  222. tpu_inference/runner/kv_cache_manager.py +15 -2
  223. tpu_inference/runner/lora_utils.py +16 -1
  224. tpu_inference/runner/multimodal_manager.py +16 -2
  225. tpu_inference/runner/persistent_batch_manager.py +14 -0
  226. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  227. tpu_inference/runner/structured_decoding_manager.py +14 -0
  228. tpu_inference/runner/tpu_runner.py +30 -10
  229. tpu_inference/spec_decode/__init__.py +13 -0
  230. tpu_inference/spec_decode/jax/__init__.py +13 -0
  231. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  232. tpu_inference/tpu_info.py +14 -0
  233. tpu_inference/utils.py +31 -30
  234. tpu_inference/worker/__init__.py +13 -0
  235. tpu_inference/worker/tpu_worker.py +23 -7
  236. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
  237. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  238. tpu_inference/layers/vllm/linear_common.py +0 -208
  239. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  240. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  241. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  242. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  245. tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
  246. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  247. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  248. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tpu_inference/layers/common/quantization.py (new file)
@@ -0,0 +1,282 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import itertools
+ from typing import Tuple
+
+ import jax
+ import jax.numpy as jnp
+
+ MXFP4_BLOCK_SIZE = 32
+
+
+ def quantize_tensor_to_mxfp4_packed(
+     tensor: jax.Array,
+     axis: int | tuple = -1,
+ ) -> Tuple[jax.Array, jax.Array]:
+     """Quantize a tensor to mxfp4 and pack it into uint8."""
+
+     # Perform regular block quantization.
+     tensor_q, scale = quantize_tensor(
+         jnp.float4_e2m1fn,
+         tensor,
+         axis,
+         MXFP4_BLOCK_SIZE,
+     )
+
+     # Each pair of e2m1 elements is packed into a single uint8 element.
+     bitcast_shape = tensor_q.shape[:-1] + (-1, 2)
+     tensor_q = tensor_q.reshape(bitcast_shape)
+     tensor_q_packed = jax.lax.bitcast_convert_type(tensor_q, jnp.uint8)
+
+     # Since TPU does not have native support for e8m0, we convert the scale
+     # into e8m0 manually and store it as uint8.
+     e8m0_finfo = jnp.finfo(jnp.float8_e8m0fnu)
+     _, scale_exp = jnp.frexp(scale)
+     # Subtract the exponents by one since e8m0 has no decimal.
+     scale_exp -= 1
+     scale_exp = (scale_exp - e8m0_finfo.minexp).astype(jnp.uint8)
+
+     return tensor_q_packed, scale_exp
+
+
+ def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:
+     """Unpack an e2m1 tensor that was packed into uint8."""
+     assert u8_packed_e2m1.dtype == jnp.uint8
+     e2m1 = jax.lax.bitcast_convert_type(u8_packed_e2m1, jnp.float4_e2m1fn)
+     # The bitcast adds one more dimension that splits each 8 bits into two
+     # e2m1 values; flatten it into the last dim.
+     return jnp.reshape(e2m1, e2m1.shape[:-2] + (-1, ))
+
+
+ def e8m0_to_fp32(u8: jax.Array) -> jax.Array:
+     """Convert e8m0 (that was bitcast to uint8) into fp32."""
+     assert u8.dtype == jnp.uint8
+
+     e8_finfo = jnp.finfo(jnp.float8_e8m0fnu)
+     exponents = u8.astype(jnp.int32) + e8_finfo.minexp
+     ones = jnp.ones_like(u8, dtype=jnp.float32)
+     return jnp.ldexp(ones, exponents)
+
+
+ def awq_u32_unpack_u4(awq_u32_packed: jax.Array) -> jax.Array:
+     """Unpack a u4 tensor that was packed into u32 in AWQ ordering."""
+
+     awq_u4 = jax.lax.bitcast_convert_type(awq_u32_packed, jnp.uint4)
+
+     # AWQ packs 8 uint4 into 32 bits in this order: (0, 2, 4, 6, 1, 3, 5, 7).
+     # The following list maps the order used by AWQ back to ascending order.
+     reverse_awq_order = (0, 4, 1, 5, 2, 6, 3, 7)
+     u4 = awq_u4[..., reverse_awq_order]
+     return jnp.reshape(u4, u4.shape[:-2] + (-1, ))
+
+
+ def dequantize_tensor(
+     tensor_q: jax.Array,
+     scale: jax.Array,
+     axis: int | None | tuple = -1,
+     out_dtype: jnp.dtype = jnp.bfloat16,
+ ) -> jax.Array:
+     """Dequantize a quantized tensor.
+
+     Args:
+         tensor_q: Quantized tensor.
+         scale: Quantization scale.
+         axis: The axis along which the tensor was quantized. None denotes
+             per-tensor.
+         out_dtype: Dtype of the output.
+
+     Returns:
+         Dequantized tensor_q.
+     """
+     if axis is None:
+         # Perform per-tensor dequantization.
+         axis = [i for i in range(tensor_q.ndim)]
+     if isinstance(axis, int):
+         axis = [axis]
+
+     orig_shape = tensor_q.shape
+     if tensor_q.ndim == scale.ndim:
+         # Indicates the tensor was block quantized.
+         blocked_shape = [[i] for i in orig_shape]
+         for i in axis:
+             num_blocks = scale.shape[i]
+             if tensor_q.shape[i] % num_blocks:
+                 raise ValueError(
+                     f"Unable to perform block dequantization. axis={i} of "
+                     f"{tensor_q.shape=} is not divisible by {num_blocks=}")
+             block_size = tensor_q.shape[i] // num_blocks
+
+             blocked_shape[i] = (num_blocks, block_size)
+
+         # Convert all axes into positive values.
+         axis = sorted([(i + tensor_q.ndim) % tensor_q.ndim for i in axis])
+         # Shift each axis by 1 since its original position is now occupied
+         # by the num_blocks dim. Also, if n axes before an axis were also
+         # quantized, shift its position by n.
+         axis = [1 + n + i for n, i in enumerate(axis)]
+
+         # Flatten the list of lists that contains (num_blocks, block).
+         blocked_shape = list(itertools.chain(*blocked_shape))
+         tensor_q = tensor_q.reshape(blocked_shape)
+
+     scale = jnp.expand_dims(scale, axis)
+
+     tensor = (tensor_q.astype(jnp.float32) * scale).astype(out_dtype)
+
+     return tensor.reshape(orig_shape)
+
+
+ def dequantize_tensor_from_mxfp4_packed(
+     tensor_q: jax.Array,
+     scale: jax.Array,
+     axis: int | tuple = -1,
+     out_dtype: jnp.dtype = jnp.bfloat16,
+ ) -> jax.Array:
+     """Dequantize a packed mxfp4 tensor.
+
+     Args:
+         tensor_q: fp4 tensor packed into uint8.
+         scale: e8m0 scale packed into uint8.
+         axis: The axis along which the tensor was quantized.
+         out_dtype: Dtype of the output.
+
+     Returns:
+         Dequantized tensor_q.
+     """
+     tensor_e2m1 = u8_unpack_e2m1(tensor_q)
+     scale_fp32 = e8m0_to_fp32(scale)
+
+     return dequantize_tensor(
+         tensor_e2m1,
+         scale_fp32,
+         axis,
+         out_dtype,
+     )
+
+
+ def quantize_tensor(
+     dtype: jnp.dtype,
+     tensor: jax.Array,
+     axis: int | tuple | None = -1,
+     block_size: int | None = None,
+     pad_tensor: bool = False,
+ ) -> tuple[jax.Array, jax.Array]:
+     """Quantize a tensor.
+
+     Args:
+         dtype: dtype to quantize to.
+         tensor: Unquantized tensor.
+         axis: Axis to quantize along. None denotes per-tensor.
+         block_size: Block size for block quantization.
+         pad_tensor: Whether to pad the axis to a multiple of the block size.
+
+     Returns:
+         The tensor quantized to dtype, and the quantization scale.
+     """
+     if axis is None:
+         # Perform per-tensor quantization.
+         axis = [i for i in range(tensor.ndim)]
+     if isinstance(axis, int):
+         axis = [axis]
+
+     orig_shape = tensor.shape
+     mask = jnp.ones_like(tensor, jnp.int32)
+
+     if block_size is not None:
+         if isinstance(block_size, int):
+             block_size = [block_size] * len(axis)
+
+         blocked_shape = [[i] for i in orig_shape]
+         pad_width = [[0, 0] for _ in range(tensor.ndim)]
+         for i, block in zip(axis, block_size):
+             num_blocks = (tensor.shape[i] + block - 1) // block
+             padding_size = num_blocks * block - tensor.shape[i]
+             if padding_size and not pad_tensor:
+                 raise ValueError(
+                     f"Unable to perform block quantization. axis={i} of "
+                     f"{tensor.shape=} is not divisible by {block=}")
+
+             # Pad the tensor to align with the block size.
+             pad_width[i][1] = padding_size
+
+             blocked_shape[i] = (num_blocks, block)
+
+         # To keep the padded values from affecting the scale, pad using the
+         # edge values of the tensor.
+         tensor = jnp.pad(tensor, pad_width, "edge")
+         mask = jnp.pad(mask, pad_width)
+
+         orig_shape = tensor.shape
+         # Convert all axes into positive values.
+         axis = sorted([i % tensor.ndim for i in axis])
+         # Shift each axis by 1 since its original position is now occupied
+         # by the num_blocks dim. Also, if n axes before an axis were also
+         # quantized, shift its position by n.
+         axis = [1 + n + i for n, i in enumerate(axis)]
+
+         # Flatten the list of lists that contains (num_blocks, block).
+         blocked_shape = list(itertools.chain(*blocked_shape))
+         tensor = tensor.reshape(blocked_shape)
+
+     if jnp.issubdtype(dtype, jnp.integer):
+         dtype_info = jnp.iinfo(dtype)
+     else:
+         dtype_info = jnp.finfo(dtype)
+
+     dtype_max = float(dtype_info.max)
+     dtype_min = float(dtype_info.min)
+
+     abs_max = jnp.max(jnp.abs(tensor), axis=axis, keepdims=True)
+     scale = abs_max / dtype_max
+
+     tensor_q = jnp.clip(tensor / scale, dtype_min, dtype_max)
+     tensor_q = tensor_q.reshape(orig_shape)
+     tensor_q = tensor_q.astype(dtype)
+
+     # To keep the padded values from affecting the output of a quantized
+     # matmul, mask them out with 0s.
+     tensor_q = jnp.where(mask, tensor_q, 0)
+
+     scale = jnp.squeeze(scale, axis).astype(jnp.float32)
+
+     return tensor_q, scale
+
+
+ def static_per_tensor_quantize_tensor(
+     dtype: jnp.dtype,
+     tensor: jax.Array,
+     scale: float,
+ ) -> jax.Array:
+     if jnp.issubdtype(dtype, jnp.integer):
+         dtype_info = jnp.iinfo(dtype)
+     else:
+         dtype_info = jnp.finfo(dtype)
+
+     dtype_max = float(dtype_info.max)
+     dtype_min = float(dtype_info.min)
+
+     return jnp.clip(tensor / scale, dtype_min, dtype_max).astype(dtype)
+
+
+ def quantize_kv(
+     dtype: jnp.dtype,
+     key: jax.Array,
+     value: jax.Array,
+     k_scale: float,
+     v_scale: float,
+ ) -> Tuple[jax.Array, jax.Array]:
+     """Statically quantize the key and value tensors."""
+     key = static_per_tensor_quantize_tensor(dtype, key, k_scale)
+     value = static_per_tensor_quantize_tensor(dtype, value, v_scale)
+     return key, value
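Taken together, the new quantization module above quantizes in blocks of 32, packs two e2m1 values per uint8, and stores each block scale as an e8m0 exponent in uint8. A minimal round-trip sketch of how these helpers compose (an editorial illustration against the code above, not part of the diff):

    import jax.numpy as jnp

    from tpu_inference.layers.common.quantization import (
        dequantize_tensor_from_mxfp4_packed, quantize_tensor_to_mxfp4_packed)

    # Two rows, one mxfp4 block (32 elements) per row.
    x = jnp.linspace(-4.0, 4.0, 64, dtype=jnp.bfloat16).reshape(2, 32)

    x_q, scale = quantize_tensor_to_mxfp4_packed(x, axis=-1)
    assert x_q.shape == (2, 16) and x_q.dtype == jnp.uint8     # 2 fp4 per uint8
    assert scale.shape == (2, 1) and scale.dtype == jnp.uint8  # e8m0 exponent

    # mxfp4 is lossy (and e8m0 scales are powers of two), so the round trip
    # is only approximate.
    x_dq = dequantize_tensor_from_mxfp4_packed(x_q, scale, axis=-1)
    assert x_dq.shape == x.shape and x_dq.dtype == jnp.bfloat16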
tpu_inference/layers/common/sharding.py
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import json
  import math
  from dataclasses import asdict, dataclass
@@ -26,7 +40,7 @@ class ShardingAxisNameBase:
      MLP_TENSOR = ('attn_dp', 'model', 'expert')
      MOE_TENSOR = ('attn_dp', 'model')
      EXPERT = ('attn_dp', 'expert', 'model')
-     VOCAB = ('expert', 'model')
+     VOCAB = ('expert', 'attn_dp', 'model')


  class ShardingAxisName2D:
@@ -127,6 +141,11 @@ class ShardingConfigManager:
          kv_dtype = utils.get_jax_dtype_from_str_dtype(
              cache_dtype) or jnp.bfloat16
          packing = 4 // jnp.dtype(kv_dtype).itemsize
+
+         # The default head dim is 128, but 64 is also supported as a special case.
+         if vllm_config.model_config.get_head_size() == 64:
+             packing *= 2
+
          # When num_kv_heads * 2 / packing < TP, tensor parallelism would
          # duplicate KV heads across devices, wasting kv cache memory.
          # Use attention DP instead to reduce per-device num_kv_heads and
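For context, `packing` counts how many KV-cache elements share one 32-bit word, and the head-dim-64 case doubles it (matching the dedicated hd64 attention kernel updated elsewhere in this release). A quick sketch of the resulting values (the helper below is an editorial illustration, not package code):

    import jax.numpy as jnp

    def kv_packing(kv_dtype, head_size: int) -> int:
        # Elements of this KV dtype per 32-bit word, as computed above.
        packing = 4 // jnp.dtype(kv_dtype).itemsize
        if head_size == 64:
            packing *= 2  # head-dim-64 special case added in this diff
        return packing

    assert kv_packing(jnp.bfloat16, 128) == 2       # 2-byte dtype
    assert kv_packing(jnp.float8_e4m3fn, 128) == 4  # 1-byte dtype
    assert kv_packing(jnp.bfloat16, 64) == 4        # doubled for head dim 64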
@@ -172,8 +191,8 @@ class ShardingConfigManager:
          if sharding_strategy.attention_data_parallelism > 1:
              if not envs.NEW_MODEL_DESIGN:
                  raise ValueError(
-                     "Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set the "
-                     "NEW_MODEL_DESIGN=True.")
+                     "Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set "
+                     "NEW_MODEL_DESIGN=True")

      @property
      def total_dp_size(self) -> int:
tpu_inference/layers/common/utils.py (new file)
@@ -0,0 +1,94 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import jax
+ import jax.numpy as jnp
+
+
+ def reorder_concatenated_tensor_for_sharding(concatenated_tensor: jax.Array,
+                                              split_sizes: list[int],
+                                              n_shards: int, dim: int):
+     """
+     Reorder a replicated concatenated tensor such that, when it is sharded
+     across multiple chips, each shard is a concatenation of the shards of
+     the individual tensors.
+
+     For example, let the concatenated_tensor be:
+         AAAAAAAAAAAABBBBBBBBCCCC
+         (12 As, 8 Bs, 4 Cs)
+     and let split_sizes = [12, 8, 4] and n_shards = 4.
+     The output is:
+         AAABBCAAABBCAAABBCAAABBC
+     In other words, the input tensor is reordered into 4 segments, with
+     each segment corresponding to one shard and reading AAABBC.
+
+     Args:
+         concatenated_tensor: the tensor, concatenated along the dimension
+             specified by `dim`.
+         split_sizes: each individual tensor's size along the dimension
+             specified by `dim`.
+         n_shards: number of shards.
+         dim: the dimension along which concatenated_tensor is concatenated.
+     """
+     # Split the concatenated tensor into the individual tensors.
+     if dim < 0:
+         dim += concatenated_tensor.ndim
+     split_tensors = []
+     start_offset = 0
+     old_shape = concatenated_tensor.shape
+     # The new shape ensures that split_tensor[i] maps to a tensor in the
+     # ith shard.
+     new_shape = old_shape[:dim] + (n_shards, -1) + old_shape[dim + 1:]
+     for split_size in split_sizes:
+         split_tensor = jax.lax.slice_in_dim(concatenated_tensor,
+                                             start_offset,
+                                             start_offset + split_size,
+                                             axis=dim)
+         split_tensors.append(split_tensor.reshape(new_shape))
+         start_offset += split_size
+     # While keeping the shard dim in place, concatenate along the next dim
+     # to create a concatenated tensor whose split dim maps to the shards.
+     reordered_tensor = jnp.concatenate(split_tensors, axis=dim + 1)
+     return reordered_tensor.reshape(old_shape)
+
+
+ def slice_sharded_tensor_for_concatenation(sharded_tensor: jax.Array,
+                                            split_sizes: list[int],
+                                            n_shards: int):
+     """
+     Slice the input tensor, which is sharded across multiple chips (on the
+     last dim), into individual tensors with the same sharding.
+
+     For example, let the sharded_tensor be:
+         AAABBC | AAABBC | AAABBC | AAABBC
+         Shard0   Shard1   Shard2   Shard3
+     and let split_sizes = [12, 8, 4] and n_shards = 4.
+     The output is a list of 3 tensors:
+         AAA | AAA | AAA | AAA
+         BB  | BB  | BB  | BB
+         C   | C   | C   | C
+     In other words, each individual tensor is a slice of the input tensor
+     with the same sharding.
+
+     Args:
+         sharded_tensor: the input tensor, sharded on the last dim.
+         split_sizes: each individual tensor's size on the last dim.
+         n_shards: number of shards.
+     """
+     new_shape = sharded_tensor.shape[:-1] + (n_shards, -1)
+     # The new shape ensures that sharded_tensor[..., i, :] maps to a tensor
+     # in the ith shard.
+     sharded_tensor = sharded_tensor.reshape(new_shape)
+
+     split_tensors = []
+     start_offset = 0
+     for split_size in split_sizes:
+         assert split_size % n_shards == 0
+         sz = split_size // n_shards  # size of this split tensor per shard
+         end_offset = start_offset + sz
+         # Because we slice over the last dim, the sharding dim stays intact,
+         # so the split happens locally on each shard.
+         split_tensor = sharded_tensor[..., start_offset:end_offset]
+         split_tensors.append(split_tensor.reshape(new_shape[:-2] + (-1, )))
+         start_offset = end_offset
+
+     return split_tensors
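A short sketch replaying the AAABBC example from the docstrings above, with A=1, B=2, C=3 (an editorial illustration, not part of the diff):

    import jax.numpy as jnp

    from tpu_inference.layers.common.utils import (
        reorder_concatenated_tensor_for_sharding,
        slice_sharded_tensor_for_concatenation)

    # 12 As, 8 Bs and 4 Cs concatenated on the last dim, as in the docstring.
    cat = jnp.concatenate([jnp.full(12, 1), jnp.full(8, 2), jnp.full(4, 3)])

    reordered = reorder_concatenated_tensor_for_sharding(
        cat, split_sizes=[12, 8, 4], n_shards=4, dim=-1)
    # Each contiguous quarter (one shard) now reads AAABBC.
    assert reordered.reshape(4, 6)[0].tolist() == [1, 1, 1, 2, 2, 3]

    # Slicing recovers the three tensors, each with per-shard layout intact.
    parts = slice_sharded_tensor_for_concatenation(reordered, [12, 8, 4], 4)
    assert [p.shape[-1] for p in parts] == [12, 8, 4]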
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
tpu_inference/layers/jax/attention/attention.py
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from dataclasses import InitVar, dataclass
  from typing import Any, Tuple

@@ -5,7 +19,6 @@ import jax
  import jax.numpy as jnp
  from flax import nnx
  from flax.typing import Sharding
- from jax.experimental import shard_map
  from jax.sharding import Mesh
  from jax.sharding import PartitionSpec as P

@@ -13,6 +26,7 @@ from tpu_inference import utils
  from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
      ragged_paged_attention
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.common.quantization import quantize_kv
  from tpu_inference.layers.common.sharding import ShardingAxisName
  from tpu_inference.layers.jax.base import create_param
  from tpu_inference.layers.jax.rope_interface import apply_rope
@@ -149,9 +163,8 @@ class Attention(nnx.Module):
          # q_scale = self._q_scale
          k_scale = self._k_scale
          v_scale = self._v_scale
-         k_SKH, v_SKH = utils.quantize_kv(k_SKH, v_SKH,
-                                          self.kv_cache_quantized_dtype,
-                                          k_scale, v_scale)
+         k_SKH, v_SKH = quantize_kv(self.kv_cache_quantized_dtype, k_SKH,
+                                    v_SKH, k_scale, v_scale)

          with jax.named_scope("attn_op"):
              new_kv_cache, outputs_TNH = self.attention(
@@ -236,12 +249,12 @@ class Attention(nnx.Module):
          )

          output_TNH, kv_cache = jax.jit(
-             shard_map.shard_map(
+             jax.shard_map(
                  _ragged_paged_attention,
                  mesh=mesh,
                  in_specs=in_specs,
                  out_specs=out_specs,
-                 check_rep=False,
+                 check_vma=False,
              ))(
                  q_TNH,
                  k_SKH,
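The last hunk tracks a JAX API migration: `shard_map` now comes from the top-level `jax` namespace rather than `jax.experimental`, and its `check_rep` flag is spelled `check_vma`. A minimal sketch of the new spelling (assuming a JAX version that exports `jax.shard_map`; the toy function is illustrative):

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, PartitionSpec as P

    mesh = Mesh(jax.devices(), ("model", ))

    def _per_shard_sum(x):
        # Sum the local shard, then reduce across the 'model' mesh axis.
        return jax.lax.psum(jnp.sum(x), "model")

    # Previously: jax.experimental.shard_map.shard_map(..., check_rep=False)
    f = jax.jit(
        jax.shard_map(
            _per_shard_sum,
            mesh=mesh,
            in_specs=P("model"),
            out_specs=P(),
            check_vma=False,
        ))
    print(f(jnp.arange(8.0)))  # 28.0, replicated across devices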