tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.
Files changed (251)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +22 -1
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +31 -9
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +77 -36
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -71
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +158 -63
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +53 -30
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +54 -2
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +105 -57
  232. tpu_inference/runner/utils.py +2 -2
  233. tpu_inference/spec_decode/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/__init__.py +13 -0
  235. tpu_inference/spec_decode/jax/eagle3.py +65 -19
  236. tpu_inference/tpu_info.py +14 -0
  237. tpu_inference/utils.py +72 -44
  238. tpu_inference/worker/__init__.py +13 -0
  239. tpu_inference/worker/tpu_worker.py +65 -52
  240. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  241. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  242. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  244. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  245. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  246. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  247. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  248. tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
  249. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  250. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  251. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/tests/kernels/gmm_test.py
@@ -0,0 +1,205 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import jax
+import jax.numpy as jnp
+from absl.testing import absltest, parameterized
+from jax._src import test_util as jtu
+
+from tpu_inference.kernels.megablox.gmm import gmm
+
+jax.config.parse_flags_with_absl()
+
+
+def quantize_tensor(x: jax.Array,
+                    dtype: jnp.dtype,
+                    axis: int = -1,
+                    block_size: int = 256):
+    if jnp.issubdtype(dtype, jnp.integer):
+        dtype_info = jnp.iinfo(dtype)
+        max_val = int(dtype_info.max)
+        min_val = int(dtype_info.min)
+    else:
+        dtype_info = jnp.finfo(dtype)
+        max_val = float(dtype_info.max)
+        min_val = float(dtype_info.min)
+
+    orig_shape = x.shape
+    blocked_shape = orig_shape[:axis] + (-1,
+                                         block_size) + orig_shape[axis + 1:]
+    x_blocked = x.reshape(blocked_shape)
+
+    x_blocked_abs_max = jnp.max(jnp.abs(x_blocked),
+                                axis=axis + 1,
+                                keepdims=True)
+    scale = x_blocked_abs_max / max_val
+    x_blocked_q = jnp.clip(x_blocked / scale, min_val, max_val).astype(dtype)
+
+    x_q = x_blocked_q.reshape(orig_shape)
+    scale = scale.squeeze(axis=axis + 1).astype(jnp.float32)
+    return x_q, scale
+
+
+def reference_gmm(
+    lhs: jax.Array,
+    rhs: jax.Array,
+    group_sizes: jax.Array,
+    rhs_scale: jax.Array | None = None,
+    rhs_bias: jax.Array | None = None,
+    group_offset: jax.Array | None = None,
+):
+    num_groups, out_size, in_size = rhs.shape
+    assert lhs.shape[1] == in_size
+
+    if group_offset is None:
+        group_offset = jnp.array(0, dtype=jnp.int32)
+    start = group_sizes[:group_offset].sum()
+    group_sizes = group_sizes[group_offset:]
+    assert len(group_sizes) == num_groups
+
+    if rhs_scale is not None:
+        num_blocks = rhs_scale.shape[1]
+    else:
+        num_blocks = 1
+    block_size = in_size // num_blocks
+
+    gmm_out = [jnp.zeros((start, out_size), lhs.dtype)]
+    for group in range(num_groups):
+        end = start + group_sizes[group]
+
+        lhs_slice = lhs[start:end]
+        rhs_slice = rhs[group]
+
+        out = 0
+        for block in range(num_blocks):
+            block_start = block * block_size
+            block_end = block_start + block_size
+            lhs_block = lhs_slice[:, block_start:block_end].astype(jnp.float32)
+            rhs_block = rhs_slice[:, block_start:block_end].astype(jnp.float32)
+
+            acc = jnp.einsum("bd,hd->bh", lhs_block, rhs_block)
+            if rhs_scale is not None:
+                acc *= rhs_scale[group][block]
+            out += acc
+        if rhs_bias is not None:
+            out = out + rhs_bias[group]
+
+        gmm_out.append(out.astype(lhs.dtype))
+        start = end
+
+    return jnp.concat(gmm_out, axis=0)
+
+
+@jtu.with_config(jax_numpy_dtype_promotion="standard")
+class GmmTest(jtu.JaxTestCase):
+
+    @parameterized.product(
+        batch_size=[128],
+        in_size=[1024],
+        out_size=[1024],
+        num_groups=[16, 32],
+        has_bias=[True, False],
+    )
+    def test_gmm(self, batch_size, in_size, out_size, num_groups, has_bias):
+        key = jax.random.key(0)
+
+        lhs = jax.random.normal(key, (batch_size, in_size), dtype=jnp.bfloat16)
+        rhs = jax.random.normal(key, (num_groups, out_size, in_size),
+                                dtype=jnp.bfloat16)
+        rhs_bias = None
+        if has_bias:
+            rhs_bias = jax.random.normal(key, (num_groups, 1, out_size),
+                                         dtype=jnp.bfloat16)
+
+        group_sizes = jax.random.randint(key, (num_groups, ),
+                                         0,
+                                         batch_size,
+                                         dtype=jnp.int32)
+
+        expected = reference_gmm(lhs, rhs, group_sizes, rhs_bias=rhs_bias)
+        actual = gmm(
+            lhs,
+            rhs,
+            group_sizes,
+            rhs_bias=rhs_bias,
+            transpose_rhs=True,
+            preferred_element_type=jnp.bfloat16,
+        )
+
+        self.assertArraysAllClose(actual, expected)
+
+    @parameterized.product(
+        batch_size=[128],
+        in_size=[1024],
+        out_size=[1024],
+        num_groups=[16, 32],
+        has_bias=[True, False],
+        weight_dtype=[jnp.int8, jnp.float8_e5m2, jnp.float4_e2m1fn],
+        block_size=[256, 512],
+    )
+    def test_gmm_weight_quantized(
+        self,
+        batch_size,
+        in_size,
+        out_size,
+        num_groups,
+        has_bias,
+        weight_dtype,
+        block_size,
+    ):
+        if weight_dtype == jnp.float4_e2m1fn and not jtu.is_device_tpu_at_least(
+                version=7):
+            self.skipTest("Expect TPUv7+")
+        key = jax.random.key(0)
+
+        lhs = jax.random.normal(key, (batch_size, in_size), dtype=jnp.bfloat16)
+        rhs = jax.random.normal(key, (num_groups, out_size, in_size),
+                                dtype=jnp.bfloat16)
+        rhs_q, rhs_scale = quantize_tensor(rhs,
+                                           weight_dtype,
+                                           axis=2,
+                                           block_size=block_size)
+        rhs_scale = jnp.swapaxes(rhs_scale, 1, 2)
+        rhs_scale = jnp.expand_dims(rhs_scale, axis=2)
+
+        rhs_bias = None
+        if has_bias:
+            rhs_bias = jax.random.normal(key, (num_groups, 1, out_size),
+                                         dtype=jnp.bfloat16)
+
+        group_sizes = jax.random.randint(key, (num_groups, ),
+                                         0,
+                                         batch_size,
+                                         dtype=jnp.int32)
+
+        expected = reference_gmm(lhs,
+                                 rhs_q,
+                                 group_sizes,
+                                 rhs_scale=rhs_scale,
+                                 rhs_bias=rhs_bias)
+        actual = gmm(
+            lhs,
+            rhs_q,
+            group_sizes,
+            rhs_scale=rhs_scale,
+            rhs_bias=rhs_bias,
+            transpose_rhs=True,
+            preferred_element_type=jnp.bfloat16,
+        )
+
+        self.assertArraysAllClose(actual, expected, atol=3e-1, rtol=3e-1)
+
+
+if __name__ == "__main__":
+    absltest.main(testLoader=jtu.JaxTestLoader())
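
A note on the blockwise scheme the quantized test above exercises: quantize_tensor emits one scale per block_size-wide slice of the quantization axis, so quantizing a (4, 1024) tensor along axis 1 with block_size=256 yields a (4, 4) scale array. A minimal round-trip sketch, assuming quantize_tensor as defined in the new test is in scope (the dequantization step below is illustrative, not part of the package):

    import jax
    import jax.numpy as jnp

    x = jax.random.normal(jax.random.key(0), (4, 1024), dtype=jnp.bfloat16)
    x_q, scale = quantize_tensor(x, jnp.int8, axis=1, block_size=256)
    # x_q: (4, 1024) int8; scale: (4, 4) float32, one entry per 256-wide block.
    x_dq = (x_q.reshape(4, 4, 256).astype(jnp.float32) *
            scale[:, :, None]).reshape(4, 1024)
    # x_dq recovers x up to per-block rounding error.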
--- a/tests/kernels/mla_v1_test.py
+++ b/tests/kernels/mla_v1_test.py
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import jax
 import jax.numpy as jnp
 import numpy as np
@@ -42,6 +56,7 @@ class MlaRaggedPagedAttentionKernelTest(jtu.JaxTestCase):

         padded_r_dim = align_to(r_dim, 128)
         padded_lkv_dim = align_to(lkv_dim, 128)
+        padded_kv_dim = padded_lkv_dim + padded_r_dim
         packing = get_dtype_packing(kv_dtype)
         q_lens = [s[0] for s in seq_lens]
         kv_lens_list = [s[1] for s in seq_lens]
@@ -69,13 +84,10 @@ class MlaRaggedPagedAttentionKernelTest(jtu.JaxTestCase):
         new_kv_c = gen_random((total_q_len, lkv_dim), kv_dtype)
         new_k_pe = gen_random((total_q_len, r_dim), kv_dtype)

-        cache_kv_c = gen_random(
-            (total_num_pages, page_size // packing, packing, padded_lkv_dim),
+        cache_kv = gen_random(
+            (total_num_pages, page_size // packing, packing, padded_kv_dim),
             kv_dtype,
         )
-        cache_k_pe = gen_random(
-            (total_num_pages, page_size // packing, packing, padded_r_dim),
-            kv_dtype)
         kv_lens = jnp.array(kv_lens_list, dtype=jnp.int32)
         page_indices = jnp.array(page_indices_list, dtype=jnp.int32)
         cu_q_lens = jnp.array(cu_q_lens_list, dtype=jnp.int32)
@@ -84,14 +96,13 @@ class MlaRaggedPagedAttentionKernelTest(jtu.JaxTestCase):
         ql_nope_for_kernel = ql_nope.copy()
         q_pe_for_kernel = q_pe.copy()

-        expected_out, expected_updated_kv_c, expeceted_updated_k_pe = (
+        expected_out, expected_updated_kv = (
             mla.ref_mla_ragged_paged_attention(
                 ql_nope,
                 q_pe,
                 new_kv_c,
                 new_k_pe,
-                cache_kv_c.copy(),
-                cache_k_pe.copy(),
+                cache_kv.copy(),
                 kv_lens,
                 page_indices,
                 cu_q_lens,
@@ -101,50 +112,141 @@ class MlaRaggedPagedAttentionKernelTest(jtu.JaxTestCase):
                 soft_cap=soft_cap,
             ))

-        kernel_out, kernel_updated_kv_c, kernel_updated_k_pe = (
-            mla.mla_ragged_paged_attention(
-                ql_nope_for_kernel,
-                q_pe_for_kernel,
-                new_kv_c,
-                new_k_pe,
-                cache_kv_c.copy(),
-                cache_k_pe.copy(),
-                kv_lens,
-                page_indices,
-                cu_q_lens,
-                distribution,
-                sm_scale=sm_scale,
-                sliding_window=sliding_window,
-                soft_cap=soft_cap,
-                num_kv_pages_per_block=num_kv_pages_per_block,
-                num_queries_per_block=num_queries_per_block,
-                vmem_limit_bytes=vmem_limit_bytes,
-            ))
+        kernel_out, kernel_updated_kv = (mla.mla_ragged_paged_attention(
+            ql_nope_for_kernel,
+            q_pe_for_kernel,
+            new_kv_c,
+            new_k_pe,
+            cache_kv.copy(),
+            kv_lens,
+            page_indices,
+            cu_q_lens,
+            distribution,
+            sm_scale=sm_scale,
+            sliding_window=sliding_window,
+            soft_cap=soft_cap,
+            num_kv_pages_per_block=num_kv_pages_per_block,
+            num_queries_per_block=num_queries_per_block,
+            vmem_limit_bytes=vmem_limit_bytes,
+        ))

         self.assertEqual(expected_out.shape,
                          (total_q_len, num_heads, padded_lkv_dim))
         self.assertEqual(
-            expected_updated_kv_c.shape,
-            (total_num_pages, page_size // packing, packing, padded_lkv_dim),
-        )
-        self.assertEqual(
-            expeceted_updated_k_pe.shape,
-            (total_num_pages, page_size // packing, packing, padded_r_dim),
+            expected_updated_kv.shape,
+            (total_num_pages, page_size // packing, packing, padded_kv_dim),
         )
         self.assertEqual(expected_out.dtype, kv_dtype)
-        self.assertEqual(expected_updated_kv_c.dtype, kv_dtype)
-        self.assertEqual(expeceted_updated_k_pe.dtype, kv_dtype)
+        self.assertEqual(expected_updated_kv.dtype, kv_dtype)

         self.assertAllClose(expected_out, kernel_out, atol=0.2, rtol=0.2)
-        self.assertAllClose(expected_updated_kv_c,
-                            kernel_updated_kv_c,
-                            atol=0.2,
-                            rtol=0.2)
-        self.assertAllClose(expeceted_updated_k_pe,
-                            kernel_updated_k_pe,
+        self.assertAllClose(expected_updated_kv,
+                            kernel_updated_kv,
                             atol=0.2,
                             rtol=0.2)

+    def test_update_kv_cache(self):
+        lkv_dim = 4
+        r_dim = 4
+        padded_lkv_dim = align_to(lkv_dim, 128)
+        padded_r_dim = align_to(r_dim, 128)
+        kv_dtype = jnp.bfloat16
+        new_kv_c = jnp.arange(16, dtype=kv_dtype).reshape((4, lkv_dim))
+        new_k_pe = (jnp.arange(16, dtype=kv_dtype).reshape((4, r_dim)) + 100)
+        total_num_pages = 2
+        page_size = 4
+        cache_kv_shape = mla.get_kv_cache_shape(
+            total_num_pages,
+            page_size,
+            padded_lkv_dim + padded_r_dim,
+            kv_dtype,
+        )
+        cache_kv = jnp.zeros(cache_kv_shape, dtype=kv_dtype)
+
+        # two sequences, first with 3 tokens, second with 1 token
+        kv_lens = jnp.array([3, 1], dtype=jnp.int32)
+        # first seq uses page 0, second uses page 1
+        page_indices = jnp.array([0, -1, 1, -1], dtype=jnp.int32)
+        # three tokens for first seq, one for second
+        cu_q_lens = jnp.array([0, 3, 4], dtype=jnp.int32)
+        distribution = jnp.array([0, 0, 2], dtype=jnp.int32)
+
+        # manually compute the expected cache
+        padded_new_kv_c = jnp.pad(new_kv_c,
+                                  ((0, 0), (0, padded_lkv_dim - lkv_dim)),
+                                  constant_values=0)
+        padded_new_k_pe = jnp.pad(new_k_pe,
+                                  ((0, 0), (0, padded_r_dim - r_dim)),
+                                  constant_values=0)
+
+        expected_cache = cache_kv
+        # First sequence
+        # token 0
+        page_idx, row, col = 0, 0, 0
+        expected_cache = expected_cache.at[page_idx, row,
+                                           col, :padded_lkv_dim].set(
+                                               padded_new_kv_c[0])
+        expected_cache = expected_cache.at[page_idx, row, col,
+                                           padded_lkv_dim:padded_lkv_dim +
+                                           padded_r_dim].set(
+                                               padded_new_k_pe[0])
+        # token 1
+        page_idx, row, col = 0, 0, 1
+        expected_cache = expected_cache.at[page_idx, row,
+                                           col, :padded_lkv_dim].set(
+                                               padded_new_kv_c[1])
+        expected_cache = expected_cache.at[page_idx, row, col,
+                                           padded_lkv_dim:padded_lkv_dim +
+                                           padded_r_dim].set(
+                                               padded_new_k_pe[1])
+        # token 2
+        page_idx, row, col = 0, 1, 0
+        expected_cache = expected_cache.at[page_idx, row,
+                                           col, :padded_lkv_dim].set(
+                                               padded_new_kv_c[2])
+        expected_cache = expected_cache.at[page_idx, row, col,
+                                           padded_lkv_dim:padded_lkv_dim +
+                                           padded_r_dim].set(
+                                               padded_new_k_pe[2])
+
+        # Second sequence
+        # token 0
+        page_idx, row, col = 1, 0, 0
+        expected_cache = expected_cache.at[page_idx, row,
+                                           col, :padded_lkv_dim].set(
+                                               padded_new_kv_c[3])
+        expected_cache = expected_cache.at[page_idx, row, col,
+                                           padded_lkv_dim:padded_lkv_dim +
+                                           padded_r_dim].set(
+                                               padded_new_k_pe[3])
+
+        updated_cache = mla.update_kv_cache(
+            new_kv_c,
+            new_k_pe,
+            cache_kv,
+            kv_lens,
+            page_indices,
+            cu_q_lens,
+            distribution,
+        )
+
+        self.assertAllClose(updated_cache, expected_cache)
+
+    def test_get_kv_cache_shape(self):
+        total_num_pages = 10
+        page_size = 16
+        lkv_dim = 128
+        kv_dtype = jnp.bfloat16
+        # The calculation for the expected shape is as follows:
+        # kv_packing is determined by the dtype, which is 2 for bfloat16.
+        # The second dimension is page_size / kv_packing = 16 / 2 = 8
+        # The third dimension is kv_packing = 2
+        # The fourth dimension is lkv_dim aligned to 128, which is 128
+        expected_shape = (10, 8, 2, 128)
+        self.assertEqual(
+            mla.get_kv_cache_shape(total_num_pages, page_size, lkv_dim,
+                                   kv_dtype), expected_shape)
+
     def test_ragged_paged_attention_basic(self):
         dtype = jnp.bfloat16
         seq_lens = [(192, 328), (128, 180), (64, 255)]
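
The mla_v1 changes above track a kernel-side refactor: the separate cache_kv_c and cache_k_pe buffers are fused into a single cache_kv whose last dimension holds the latent-KV part followed by the rope part (padded_lkv_dim + padded_r_dim), and mla.update_kv_cache / mla.get_kv_cache_shape operate on that packed layout. The shape rule asserted by test_get_kv_cache_shape, re-derived as a standalone sketch (the helper below is mine, not the kernel's implementation):

    import jax.numpy as jnp

    def kv_cache_shape(total_num_pages, page_size, kv_dim, kv_dtype):
        bits = jnp.dtype(kv_dtype).itemsize * 8  # 16 for bfloat16
        packing = 32 // bits                     # values packed per 32-bit word
        padded_kv_dim = -(-kv_dim // 128) * 128  # align_to(kv_dim, 128)
        return (total_num_pages, page_size // packing, packing, padded_kv_dim)

    assert kv_cache_shape(10, 16, 128, jnp.bfloat16) == (10, 8, 2, 128)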
--- a/tests/kernels/quantized_matmul_kernel_test.py
+++ b/tests/kernels/quantized_matmul_kernel_test.py
@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0

-import functools
-
 import jax
 import jax.numpy as jnp
 from absl.testing import absltest, parameterized
@@ -10,6 +8,7 @@ from jax._src import test_util as jtu
 from tpu_inference.kernels.quantized_matmul import (kernel, tuned_block_sizes,
                                                     util)

+xla_quantized_matmul = kernel.xla_quantized_matmul
 quantized_matmul_kernel = kernel.quantized_matmul_kernel
 quantize_tensor = util.quantize_tensor
 get_tuned_block_sizes = tuned_block_sizes.get_tuned_block_sizes
@@ -17,37 +16,6 @@ get_tuned_block_sizes = tuned_block_sizes.get_tuned_block_sizes
 jax.config.parse_flags_with_absl()


-@functools.partial(jax.jit, static_argnames=["quantize_activation"])
-def reference_quantized_matmul(
-    x: jax.Array,
-    w_q: jax.Array,
-    w_scale: jax.Array,
-    quantize_activation=True,
-):
-    if quantize_activation:
-        acc_dtype = jnp.float32
-        if quantize_activation and jnp.issubdtype(w_q.dtype, jnp.integer):
-            acc_dtype = jnp.int32
-
-        x_q, x_scale = quantize_tensor(x, w_q.dtype)
-        out = jax.lax.dot_general(
-            x_q,
-            w_q,
-            dimension_numbers=(((1, ), (1, )), ((), ())),
-            preferred_element_type=acc_dtype,
-        ).astype(jnp.float32)
-        out *= x_scale
-    else:
-        out = jax.lax.dot_general(
-            x,
-            w_q,
-            dimension_numbers=(((1, ), (1, )), ((), ())),
-            preferred_element_type=jnp.float32,
-        )
-    out *= jnp.expand_dims(w_scale, 0)
-    return out.astype(x.dtype)
-
-
 @jtu.with_config(jax_numpy_dtype_promotion="standard")
 class QuantizedMatmulKernelTest(jtu.JaxTestCase):

@@ -94,7 +62,7 @@ class QuantizedMatmulKernelTest(jtu.JaxTestCase):
             x_q_dtype=x_q_dtype,
             tuned_value=tuned_value,
         )
-        expected = reference_quantized_matmul(
+        expected = xla_quantized_matmul(
            x, w_q, w_scale, quantize_activation=quantize_activation)

         self.assertAllClose(output,
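
With the local reference_quantized_matmul deleted, the test now checks the Pallas kernel against xla_quantized_matmul, the reference that ships in tpu_inference.kernels.quantized_matmul.kernel itself. For context, the weight-only branch of the removed helper reduces to a float matmul followed by a per-output-channel rescale; a condensed sketch of that math (the function name is local to this sketch):

    import jax
    import jax.numpy as jnp

    def weight_only_matmul(x, w_q, w_scale):
        # x: (m, k); w_q: (n, k) quantized weights; w_scale: (n,) float32.
        out = jax.lax.dot_general(
            x,
            w_q,
            dimension_numbers=(((1, ), (1, )), ((), ())),
            preferred_element_type=jnp.float32,
        )
        out *= jnp.expand_dims(w_scale, 0)  # dequantize per output channel
        return out.astype(x.dtype)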
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import jax
 import jax.numpy as jnp
 import numpy as np
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import random

 import jax
--- a/tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py
+++ b/tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import jax
 import jax.numpy as jnp
 import numpy as np
@@ -176,7 +190,9 @@ class RaggedPagedAttentionHeadDim64KernelTest(jtu.JaxTestCase):
         )
         output = output[:cu_q_lens[distribution[-1]]]

-        dtype_bits = dtypes.bit_width(jnp.dtype(kv_dtype))
+        dtype_bits = (dtypes.bit_width(jnp.dtype(kv_dtype)) if hasattr(
+            dtypes, "bit_width") else dtypes.itemsize_bits(
+                jnp.dtype(kv_dtype)))
         tols = {
             32: 0.15,
             16: 0.2,
--- a/tests/kernels/ragged_paged_attention_kernel_v3_test.py
+++ b/tests/kernels/ragged_paged_attention_kernel_v3_test.py
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import jax
 import jax.numpy as jnp
 import numpy as np
@@ -162,7 +176,9 @@ class RaggedPagedAttentionKernelTest(jtu.JaxTestCase):
         )
         output = output[:cu_q_lens[distribution[-1]]]

-        dtype_bits = dtypes.bit_width(jnp.dtype(kv_dtype))
+        dtype_bits = (dtypes.bit_width(jnp.dtype(kv_dtype)) if hasattr(
+            dtypes, "bit_width") else dtypes.itemsize_bits(
+                jnp.dtype(kv_dtype)))
         tols = {
             32: 0.15,
             16: 0.2,
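
Both ragged-paged-attention v3 tests now guard the bit-width lookup because the helper's name in jax._src.dtypes differs across JAX releases. Extracted from the diff, the guard amounts to the following (only one of the two attributes is assumed to exist on any given JAX install):

    import jax.numpy as jnp
    from jax._src import dtypes

    def dtype_bits(dtype) -> int:
        # Use dtypes.bit_width when the installed JAX provides it,
        # otherwise fall back to dtypes.itemsize_bits.
        d = jnp.dtype(dtype)
        if hasattr(dtypes, "bit_width"):
            return dtypes.bit_width(d)
        return dtypes.itemsize_bits(d)

    tol = {32: 0.15, 16: 0.2}[dtype_bits(jnp.bfloat16)]  # -> 0.2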
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.