tpu_inference-0.12.0.dev20251222-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tests/kernels/gmm_test.py
@@ -0,0 +1,205 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import jax
+import jax.numpy as jnp
+from absl.testing import absltest, parameterized
+from jax._src import test_util as jtu
+
+from tpu_inference.kernels.megablox.gmm import gmm
+
+jax.config.parse_flags_with_absl()
+
+
+def quantize_tensor(x: jax.Array,
+                    dtype: jnp.dtype,
+                    axis: int = -1,
+                    block_size: int = 256):
+    if jnp.issubdtype(dtype, jnp.integer):
+        dtype_info = jnp.iinfo(dtype)
+        max_val = int(dtype_info.max)
+        min_val = int(dtype_info.min)
+    else:
+        dtype_info = jnp.finfo(dtype)
+        max_val = float(dtype_info.max)
+        min_val = float(dtype_info.min)
+
+    orig_shape = x.shape
+    blocked_shape = orig_shape[:axis] + (-1,
+                                         block_size) + orig_shape[axis + 1:]
+    x_blocked = x.reshape(blocked_shape)
+
+    x_blocked_abs_max = jnp.max(jnp.abs(x_blocked),
+                                axis=axis + 1,
+                                keepdims=True)
+    scale = x_blocked_abs_max / max_val
+    x_blocked_q = jnp.clip(x_blocked / scale, min_val, max_val).astype(dtype)
+
+    x_q = x_blocked_q.reshape(orig_shape)
+    scale = scale.squeeze(axis=axis + 1).astype(jnp.float32)
+    return x_q, scale
+
+
+def reference_gmm(
+    lhs: jax.Array,
+    rhs: jax.Array,
+    group_sizes: jax.Array,
+    rhs_scale: jax.Array | None = None,
+    rhs_bias: jax.Array | None = None,
+    group_offset: jax.Array | None = None,
+):
+    num_groups, out_size, in_size = rhs.shape
+    assert lhs.shape[1] == in_size
+
+    if group_offset is None:
+        group_offset = jnp.array(0, dtype=jnp.int32)
+    start = group_sizes[:group_offset].sum()
+    group_sizes = group_sizes[group_offset:]
+    assert len(group_sizes) == num_groups
+
+    if rhs_scale is not None:
+        num_blocks = rhs_scale.shape[1]
+    else:
+        num_blocks = 1
+    block_size = in_size // num_blocks
+
+    gmm_out = [jnp.zeros((start, out_size), lhs.dtype)]
+    for group in range(num_groups):
+        end = start + group_sizes[group]
+
+        lhs_slice = lhs[start:end]
+        rhs_slice = rhs[group]
+
+        out = 0
+        for block in range(num_blocks):
+            block_start = block * block_size
+            block_end = block_start + block_size
+            lhs_block = lhs_slice[:, block_start:block_end].astype(jnp.float32)
+            rhs_block = rhs_slice[:, block_start:block_end].astype(jnp.float32)
+
+            acc = jnp.einsum("bd,hd->bh", lhs_block, rhs_block)
+            if rhs_scale is not None:
+                acc *= rhs_scale[group][block]
+            out += acc
+        if rhs_bias is not None:
+            out = out + rhs_bias[group]
+
+        gmm_out.append(out.astype(lhs.dtype))
+        start = end
+
+    return jnp.concat(gmm_out, axis=0)
+
+
+@jtu.with_config(jax_numpy_dtype_promotion="standard")
+class GmmTest(jtu.JaxTestCase):
+
+    @parameterized.product(
+        batch_size=[128],
+        in_size=[1024],
+        out_size=[1024],
+        num_groups=[16, 32],
+        has_bias=[True, False],
+    )
+    def test_gmm(self, batch_size, in_size, out_size, num_groups, has_bias):
+        key = jax.random.key(0)
+
+        lhs = jax.random.normal(key, (batch_size, in_size), dtype=jnp.bfloat16)
+        rhs = jax.random.normal(key, (num_groups, out_size, in_size),
+                                dtype=jnp.bfloat16)
+        rhs_bias = None
+        if has_bias:
+            rhs_bias = jax.random.normal(key, (num_groups, 1, out_size),
+                                         dtype=jnp.bfloat16)
+
+        group_sizes = jax.random.randint(key, (num_groups, ),
+                                         0,
+                                         batch_size,
+                                         dtype=jnp.int32)
+
+        expected = reference_gmm(lhs, rhs, group_sizes, rhs_bias=rhs_bias)
+        actual = gmm(
+            lhs,
+            rhs,
+            group_sizes,
+            rhs_bias=rhs_bias,
+            transpose_rhs=True,
+            preferred_element_type=jnp.bfloat16,
+        )
+
+        self.assertArraysAllClose(actual, expected)
+
+    @parameterized.product(
+        batch_size=[128],
+        in_size=[1024],
+        out_size=[1024],
+        num_groups=[16, 32],
+        has_bias=[True, False],
+        weight_dtype=[jnp.int8, jnp.float8_e5m2, jnp.float4_e2m1fn],
+        block_size=[256, 512],
+    )
+    def test_gmm_weight_quantized(
+        self,
+        batch_size,
+        in_size,
+        out_size,
+        num_groups,
+        has_bias,
+        weight_dtype,
+        block_size,
+    ):
+        if weight_dtype == jnp.float4_e2m1fn and not jtu.is_device_tpu_at_least(
+                version=7):
+            self.skipTest("Expect TPUv7+")
+        key = jax.random.key(0)
+
+        lhs = jax.random.normal(key, (batch_size, in_size), dtype=jnp.bfloat16)
+        rhs = jax.random.normal(key, (num_groups, out_size, in_size),
+                                dtype=jnp.bfloat16)
+        rhs_q, rhs_scale = quantize_tensor(rhs,
+                                           weight_dtype,
+                                           axis=2,
+                                           block_size=block_size)
+        rhs_scale = jnp.swapaxes(rhs_scale, 1, 2)
+        rhs_scale = jnp.expand_dims(rhs_scale, axis=2)
+
+        rhs_bias = None
+        if has_bias:
+            rhs_bias = jax.random.normal(key, (num_groups, 1, out_size),
+                                         dtype=jnp.bfloat16)
+
+        group_sizes = jax.random.randint(key, (num_groups, ),
+                                         0,
+                                         batch_size,
+                                         dtype=jnp.int32)
+
+        expected = reference_gmm(lhs,
+                                 rhs_q,
+                                 group_sizes,
+                                 rhs_scale=rhs_scale,
+                                 rhs_bias=rhs_bias)
+        actual = gmm(
+            lhs,
+            rhs_q,
+            group_sizes,
+            rhs_scale=rhs_scale,
+            rhs_bias=rhs_bias,
+            transpose_rhs=True,
+            preferred_element_type=jnp.bfloat16,
+        )
+
+        self.assertArraysAllClose(actual, expected, atol=3e-1, rtol=3e-1)
+
+
+if __name__ == "__main__":
+    absltest.main(testLoader=jtu.JaxTestLoader())
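
The test above exercises the megablox gmm (grouped matmul) kernel with blockwise absmax weight quantization: quantize_tensor splits each row of the weight tensor into blocks of block_size elements, keeps one float32 scale per block, and reference_gmm multiplies each block's partial product by that scale. As a minimal standalone sketch of the scheme (illustrative only, not code from this package), using int8 and a toy block size:

    import jax.numpy as jnp

    x = jnp.arange(1., 9.).reshape(2, 4)                # toy weight rows, in_size = 4
    block_size = 2                                      # -> 2 blocks per row
    xb = x.reshape(2, 2, block_size)                    # (rows, num_blocks, block_size)
    scale = jnp.max(jnp.abs(xb), axis=-1, keepdims=True) / 127.  # one scale per block
    x_q = jnp.clip(xb / scale, -128, 127).astype(jnp.int8)       # quantized blocks
    x_deq = x_q.astype(jnp.float32) * scale             # blockwise dequantization
    print(jnp.max(jnp.abs(x_deq.reshape(2, 4) - x)))    # reconstruction error is small

The kernel never materializes x_deq; as in reference_gmm, each block's scale is folded into the partial matmul accumulator instead, which is mathematically equivalent.
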
tests/kernels/mla_v1_test.py
@@ -0,0 +1,498 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from absl.testing import absltest, parameterized
+from jax._src import test_util as jtu
+
+import tpu_inference.kernels.mla.v1.kernel as mla
+from tpu_inference.kernels.ragged_paged_attention.v3.util import (
+    align_to, cdiv, get_dtype_packing)
+
+jax.config.parse_flags_with_absl()
+
+
+@jtu.with_config(jax_numpy_dtype_promotion="standard")
+class MlaRaggedPagedAttentionKernelTest(jtu.JaxTestCase):
+
+    def _test_mla_ragged_paged_attention(
+        self,
+        seq_lens,  # List[(q_len, kv_len)]
+        num_heads,
+        lkv_dim,
+        r_dim,
+        page_size,
+        q_dtype,
+        kv_dtype,
+        num_pages,
+        *,
+        num_kv_pages_per_block=8,
+        num_queries_per_block=8,
+        vmem_limit_bytes=100 * 1024 * 1024,
+        sm_scale=1.0,
+        sliding_window: int | None = None,
+        soft_cap: float | None = None,
+    ):
+        if not jtu.is_device_tpu_at_least(version=4):
+            self.skipTest("Expect TPUv4+")
+        rng = np.random.default_rng(1234)
+
+        def gen_random(shape, dtype):
+            return jnp.array(rng.random(size=shape,
+                                        dtype=np.float32)).astype(dtype)
+
+        padded_r_dim = align_to(r_dim, 128)
+        padded_lkv_dim = align_to(lkv_dim, 128)
+        padded_kv_dim = padded_lkv_dim + padded_r_dim
+        packing = get_dtype_packing(kv_dtype)
+        q_lens = [s[0] for s in seq_lens]
+        kv_lens_list = [s[1] for s in seq_lens]
+        total_q_len = sum(q_lens)
+        cu_q_lens_list = [0]
+        for q_len in q_lens:
+            cu_q_lens_list.append(cu_q_lens_list[-1] + q_len)
+
+        max_kv_len = max(kv_lens_list) if kv_lens_list else 0
+        pages_per_seq = cdiv(max_kv_len, page_size)
+
+        page_indices_list = []
+        page_count = 0
+        for kv_len in kv_lens_list:
+            num_seq_pages = cdiv(kv_len, page_size)
+            indices = list(range(page_count, page_count + num_seq_pages))
+            page_indices_list.extend(indices + [-1] *
+                                     (pages_per_seq - num_seq_pages))
+            page_count += num_seq_pages
+
+        total_num_pages = max(num_pages, page_count)
+
+        ql_nope = gen_random((total_q_len, num_heads, lkv_dim), q_dtype)
+        q_pe = gen_random((total_q_len, num_heads, r_dim), q_dtype)
+        new_kv_c = gen_random((total_q_len, lkv_dim), kv_dtype)
+        new_k_pe = gen_random((total_q_len, r_dim), kv_dtype)
+
+        cache_kv = gen_random(
+            (total_num_pages, page_size // packing, packing, padded_kv_dim),
+            kv_dtype,
+        )
+        kv_lens = jnp.array(kv_lens_list, dtype=jnp.int32)
+        page_indices = jnp.array(page_indices_list, dtype=jnp.int32)
+        cu_q_lens = jnp.array(cu_q_lens_list, dtype=jnp.int32)
+        distribution = jnp.array([0, 0, len(seq_lens)], dtype=jnp.int32)
+
+        ql_nope_for_kernel = ql_nope.copy()
+        q_pe_for_kernel = q_pe.copy()
+
+        expected_out, expected_updated_kv = (
+            mla.ref_mla_ragged_paged_attention(
+                ql_nope,
+                q_pe,
+                new_kv_c,
+                new_k_pe,
+                cache_kv.copy(),
+                kv_lens,
+                page_indices,
+                cu_q_lens,
+                distribution,
+                sm_scale=sm_scale,
+                sliding_window=sliding_window,
+                soft_cap=soft_cap,
+            ))
+
+        kernel_out, kernel_updated_kv = (mla.mla_ragged_paged_attention(
+            ql_nope_for_kernel,
+            q_pe_for_kernel,
+            new_kv_c,
+            new_k_pe,
+            cache_kv.copy(),
+            kv_lens,
+            page_indices,
+            cu_q_lens,
+            distribution,
+            sm_scale=sm_scale,
+            sliding_window=sliding_window,
+            soft_cap=soft_cap,
+            num_kv_pages_per_block=num_kv_pages_per_block,
+            num_queries_per_block=num_queries_per_block,
+            vmem_limit_bytes=vmem_limit_bytes,
+        ))
+
+        self.assertEqual(expected_out.shape,
+                         (total_q_len, num_heads, padded_lkv_dim))
+        self.assertEqual(
+            expected_updated_kv.shape,
+            (total_num_pages, page_size // packing, packing, padded_kv_dim),
+        )
+        self.assertEqual(expected_out.dtype, kv_dtype)
+        self.assertEqual(expected_updated_kv.dtype, kv_dtype)
+
+        self.assertAllClose(expected_out, kernel_out, atol=0.2, rtol=0.2)
+        self.assertAllClose(expected_updated_kv,
+                            kernel_updated_kv,
+                            atol=0.2,
+                            rtol=0.2)
+
+    def test_update_kv_cache(self):
+        lkv_dim = 4
+        r_dim = 4
+        padded_lkv_dim = align_to(lkv_dim, 128)
+        padded_r_dim = align_to(r_dim, 128)
+        kv_dtype = jnp.bfloat16
+        new_kv_c = jnp.arange(16, dtype=kv_dtype).reshape((4, lkv_dim))
+        new_k_pe = (jnp.arange(16, dtype=kv_dtype).reshape((4, r_dim)) + 100)
+        total_num_pages = 2
+        page_size = 4
+        cache_kv_shape = mla.get_kv_cache_shape(
+            total_num_pages,
+            page_size,
+            padded_lkv_dim + padded_r_dim,
+            kv_dtype,
+        )
+        cache_kv = jnp.zeros(cache_kv_shape, dtype=kv_dtype)
+
+        # two sequences, first with 3 tokens, second with 1 token
+        kv_lens = jnp.array([3, 1], dtype=jnp.int32)
+        # first seq uses page 0, second uses page 1
+        page_indices = jnp.array([0, -1, 1, -1], dtype=jnp.int32)
+        # three tokens for first seq, one for second
+        cu_q_lens = jnp.array([0, 3, 4], dtype=jnp.int32)
+        distribution = jnp.array([0, 0, 2], dtype=jnp.int32)
+
+        # manually compute the expected cache
+        padded_new_kv_c = jnp.pad(new_kv_c,
+                                  ((0, 0), (0, padded_lkv_dim - lkv_dim)),
+                                  constant_values=0)
+        padded_new_k_pe = jnp.pad(new_k_pe,
+                                  ((0, 0), (0, padded_r_dim - r_dim)),
+                                  constant_values=0)
+
+        expected_cache = cache_kv
+        # First sequence
+        # token 0
+        page_idx, row, col = 0, 0, 0
+        expected_cache = expected_cache.at[page_idx, row,
+                                           col, :padded_lkv_dim].set(
+                                               padded_new_kv_c[0])
+        expected_cache = expected_cache.at[page_idx, row, col,
+                                           padded_lkv_dim:padded_lkv_dim +
+                                           padded_r_dim].set(
+                                               padded_new_k_pe[0])
+        # token 1
+        page_idx, row, col = 0, 0, 1
+        expected_cache = expected_cache.at[page_idx, row,
+                                           col, :padded_lkv_dim].set(
+                                               padded_new_kv_c[1])
+        expected_cache = expected_cache.at[page_idx, row, col,
+                                           padded_lkv_dim:padded_lkv_dim +
+                                           padded_r_dim].set(
+                                               padded_new_k_pe[1])
+        # token 2
+        page_idx, row, col = 0, 1, 0
+        expected_cache = expected_cache.at[page_idx, row,
+                                           col, :padded_lkv_dim].set(
+                                               padded_new_kv_c[2])
+        expected_cache = expected_cache.at[page_idx, row, col,
+                                           padded_lkv_dim:padded_lkv_dim +
+                                           padded_r_dim].set(
+                                               padded_new_k_pe[2])
+
+        # Second sequence
+        # token 0
+        page_idx, row, col = 1, 0, 0
+        expected_cache = expected_cache.at[page_idx, row,
+                                           col, :padded_lkv_dim].set(
+                                               padded_new_kv_c[3])
+        expected_cache = expected_cache.at[page_idx, row, col,
+                                           padded_lkv_dim:padded_lkv_dim +
+                                           padded_r_dim].set(
+                                               padded_new_k_pe[3])
+
+        updated_cache = mla.update_kv_cache(
+            new_kv_c,
+            new_k_pe,
+            cache_kv,
+            kv_lens,
+            page_indices,
+            cu_q_lens,
+            distribution,
+        )
+
+        self.assertAllClose(updated_cache, expected_cache)
+
+    def test_get_kv_cache_shape(self):
+        total_num_pages = 10
+        page_size = 16
+        lkv_dim = 128
+        kv_dtype = jnp.bfloat16
+        # The calculation for the expected shape is as follows:
+        # kv_packing is determined by the dtype, which is 2 for bfloat16.
+        # The second dimension is page_size / kv_packing = 16 / 2 = 8
+        # The third dimension is kv_packing = 2
+        # The fourth dimension is lkv_dim aligned to 128, which is 128
+        expected_shape = (10, 8, 2, 128)
+        self.assertEqual(
+            mla.get_kv_cache_shape(total_num_pages, page_size, lkv_dim,
+                                   kv_dtype), expected_shape)
+
+    def test_ragged_paged_attention_basic(self):
+        dtype = jnp.bfloat16
+        seq_lens = [(192, 328), (128, 180), (64, 255)]
+        num_heads = 128
+        lkv_dim = 512
+        r_dim = 64
+        page_size = 16
+        num_pages = 1000
+
+        self._test_mla_ragged_paged_attention(
+            seq_lens,
+            num_heads,
+            lkv_dim,
+            r_dim,
+            page_size,
+            dtype,
+            dtype,
+            num_pages,
+        )
+
+    @parameterized.product(dtype=[jnp.bfloat16], )
+    def test_ragged_paged_attention_decode_only(self, dtype):
+        seq_lens = [
+            (1, 18),
+            (1, 129),
+            (1, 597),
+            (1, 122),
+            (1, 64),
+            (1, 322),
+            (1, 463),
+            (1, 181),
+            (1, 1107),
+            (1, 123),
+            (1, 31),
+            (1, 18),
+            (1, 1229),
+            (1, 229),
+            (1, 87),
+            (1, 1328),
+        ]
+        num_heads = 128
+        lkv_dim = 512
+        r_dim = 64
+        page_size = 16
+        num_pages = 1000
+
+        self._test_mla_ragged_paged_attention(
+            seq_lens,
+            num_heads,
+            lkv_dim,
+            r_dim,
+            page_size,
+            dtype,
+            dtype,
+            num_pages,
+        )
+
+    @parameterized.product(dtype=[jnp.bfloat16], )
+    def test_ragged_paged_attention_prefill_only(self, dtype):
+        seq_lens = [
+            (5, 18),
+            (15, 129),
+            (120, 597),
+            (100, 122),
+            (21, 64),
+            (32, 322),
+            (251, 463),
+            (40, 181),
+            (64, 1107),
+            (99, 123),
+            (10, 31),
+            (5, 18),
+            (3, 1229),
+            (120, 229),
+            (9, 87),
+            (2, 1328),
+        ]
+        num_heads = 128
+        lkv_dim = 512
+        r_dim = 64
+        page_size = 16
+        num_pages = 1000
+
+        self._test_mla_ragged_paged_attention(
+            seq_lens,
+            num_heads,
+            lkv_dim,
+            r_dim,
+            page_size,
+            dtype,
+            dtype,
+            num_pages,
+        )
+
+    @parameterized.product(dtype=[jnp.bfloat16], )
+    def test_ragged_paged_attention_mixed(self, dtype):
+        seq_lens = [
+            (5, 18),
+            (1, 129),
+            (120, 597),
+            (1, 122),
+            (1, 64),
+            (32, 322),
+            (251, 463),
+            (1, 181),
+            (1, 1107),
+            (99, 123),
+            (1, 31),
+            (5, 18),
+            (3, 1229),
+            (117, 229),
+            (1, 87),
+            (1, 1328),
+        ]
+        num_heads = 128
+        lkv_dim = 512
+        r_dim = 64
+        page_size = 16
+        num_pages = 1000
+
+        self._test_mla_ragged_paged_attention(
+            seq_lens,
+            num_heads,
+            lkv_dim,
+            r_dim,
+            page_size,
+            dtype,
+            dtype,
+            num_pages,
+        )
+
+    @parameterized.product(sliding_window=[None, 5, 128], )
+    def test_ragged_paged_attention_sliding_window(
+        self,
+        sliding_window: int | None,
+    ):
+        num_seqs = 5
+        num_heads = 128
+        lkv_dim = 512
+        r_dim = 64
+        dtype = jnp.float32
+        rng = np.random.default_rng(1234)
+        q_lens = rng.integers(1, 100, num_seqs)
+        kv_lens = q_lens + rng.integers(0, 50, num_seqs)
+        seq_lens = list(zip(q_lens.tolist(), kv_lens.tolist()))
+        page_size = 16
+        num_pages = 1000
+
+        self._test_mla_ragged_paged_attention(
+            seq_lens,
+            num_heads,
+            lkv_dim,
+            r_dim,
+            page_size,
+            dtype,
+            dtype,
+            num_pages,
+            sliding_window=sliding_window,
+        )
+
+    @parameterized.product(soft_cap=[None, 50.0], )
+    def test_ragged_paged_attention_logit_soft_capping(
+        self,
+        soft_cap: float | None,
+    ):
+        num_heads = 128
+        num_seqs = 2
+        dtype = jnp.float32
+        rng = np.random.default_rng(1234)
+        q_lens = rng.integers(1, 100, num_seqs)
+        kv_lens = q_lens + rng.integers(0, 50, num_seqs)
+        seq_lens = list(zip(q_lens.tolist(), kv_lens.tolist()))
+        lkv_dim = 512
+        r_dim = 64
+        page_size = 16
+        num_pages = 1000
+
+        self._test_mla_ragged_paged_attention(
+            seq_lens,
+            num_heads,
+            lkv_dim,
+            r_dim,
+            page_size,
+            dtype,
+            dtype,
+            num_pages,
+            soft_cap=soft_cap,
+        )
+
+    def test_ragged_paged_attention_sliding_window_should_be_positive(self):
+        dtype = jnp.float32
+        seq_lens = [(192, 328), (128, 180), (64, 255)]
+        num_heads = 128
+        lkv_dim = 512
+        r_dim = 64
+        page_size = 16
+        num_pages = 1000
+
+        with self.assertRaisesRegex(ValueError, "must be positive"):
+            self._test_mla_ragged_paged_attention(
+                seq_lens,
+                num_heads,
+                lkv_dim,
+                r_dim,
+                page_size,
+                dtype,
+                dtype,
+                num_pages,
+                sliding_window=0,
+            )
+
+        with self.assertRaisesRegex(ValueError, "must be positive"):
+            self._test_mla_ragged_paged_attention(
+                seq_lens,
+                num_heads,
+                lkv_dim,
+                r_dim,
+                page_size,
+                dtype,
+                dtype,
+                num_pages,
+                sliding_window=-1,
+            )
+
+    def test_ragged_paged_attention_soft_cap_cannot_be_zero(self):
+        dtype = jnp.float32
+        seq_lens = [(192, 328), (128, 180), (64, 255)]
+        num_heads = 128
+        lkv_dim = 512
+        r_dim = 64
+        page_size = 16
+        num_pages = 1000
+
+        with self.assertRaisesRegex(ValueError, "must not be 0.0"):
+            self._test_mla_ragged_paged_attention(
+                seq_lens,
+                num_heads,
+                lkv_dim,
+                r_dim,
+                page_size,
+                dtype,
+                dtype,
+                num_pages,
+                soft_cap=0.0,
+            )
+
+
+if __name__ == "__main__":
+    absltest.main(testLoader=jtu.JaxTestLoader())
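
The shape asserted in test_get_kv_cache_shape reflects how the MLA kernel packs its KV cache: the latent KV and rotary dimensions are padded up to a multiple of 128, and each page of page_size tokens is stored as (page_size // packing, packing) rows, where packing is the number of KV-dtype values per 32-bit word (2 for bfloat16). A minimal sketch of that arithmetic (an illustrative reconstruction, not the package's implementation; the real get_dtype_packing in tpu_inference.kernels.ragged_paged_attention.v3.util may treat sub-byte dtypes differently):

    import numpy as np
    import jax.numpy as jnp

    def kv_cache_shape(num_pages, page_size, kv_dim, kv_dtype, align=128):
        # how many values of kv_dtype fit in one 32-bit word (2 for bfloat16)
        packing = 32 // (np.dtype(kv_dtype).itemsize * 8)
        padded_kv_dim = -(-kv_dim // align) * align  # round kv_dim up to `align`
        return (num_pages, page_size // packing, packing, padded_kv_dim)

    # reproduces the expected shape asserted in test_get_kv_cache_shape
    assert kv_cache_shape(10, 16, 128, jnp.bfloat16) == (10, 8, 2, 128)
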