tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,520 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ from absl.testing import absltest, parameterized
+ from jax._src import dtypes
+ from jax._src import test_util as jtu
+
+ from tpu_inference.kernels.ragged_paged_attention.v3.kernel import (
+     ragged_paged_attention, ref_ragged_paged_attention)
+ from tpu_inference.kernels.ragged_paged_attention.v3.util import (
+     align_to, cdiv, get_dtype_packing)
+
+ jax.config.parse_flags_with_absl()
+
+
+ @jtu.with_config(jax_numpy_dtype_promotion="standard")
+ class RaggedPagedAttentionKernelTest(jtu.JaxTestCase):
+
+     def _test_ragged_paged_attention(
+         self,
+         seq_lens,  # List[(q_len, kv_len)]
+         num_heads,  # [num_q_heads, num_kv_heads]
+         head_dim,
+         page_size,
+         q_dtype,
+         kv_dtype,
+         num_pages,
+         *,
+         num_kv_pages_per_block=8,
+         num_queries_per_block=64,
+         vmem_limit_bytes=100 * 1024 * 1024,
+         max_num_batched_tokens=512,
+         max_num_seq=8,
+         sliding_window: int | None = None,
+         soft_cap: float | None = None,
+         q_scale: float | None = None,
+         k_scale: float | None = None,
+         v_scale: float | None = None,
+     ):
+         rng = np.random.default_rng(1234)
+
+         def gen_random(shape, dtype):
+             return jnp.array(rng.random(size=shape,
+                                         dtype=np.float32)).astype(dtype)
+
+         if not jtu.is_device_tpu_at_least(version=4):
+             self.skipTest("Expect TPUv4+")
+         cu_q_lens = [0]
+         kv_lens = []
+         for q_len, kv_len in seq_lens:
+             assert q_len <= kv_len
+             cu_q_lens.append(cu_q_lens[-1] + q_len)
+             kv_lens.append(kv_len)
+
+         max_num_batched_tokens = max(align_to(cu_q_lens[-1], 128),
+                                      max_num_batched_tokens)
+         max_num_seq = max(align_to(len(seq_lens), 8), max_num_seq)
+         max_kv_len = max(kv_lens)
+         pages_per_seq = cdiv(max_kv_len, page_size)
+         num_q_heads, num_kv_heads = num_heads
+
+         q = gen_random((max_num_batched_tokens, num_q_heads, head_dim),
+                        q_dtype)
+         k = gen_random((max_num_batched_tokens, num_kv_heads, head_dim),
+                        kv_dtype)
+         v = gen_random((max_num_batched_tokens, num_kv_heads, head_dim),
+                        kv_dtype)
+         page_cnt = 0
+         page_indices_list = []
+         kv_pages_list = []
+         kv_packing = get_dtype_packing(kv_dtype)
+         padded_head_dim = align_to(head_dim, 128)
+         num_kv_heads_x2 = align_to(num_kv_heads * 2, kv_packing)
+         for kv_len in kv_lens:
+             kv = gen_random((
+                 kv_len,
+                 num_kv_heads_x2 // kv_packing,
+                 kv_packing,
+                 padded_head_dim,
+             ), kv_dtype)
+             kv = jnp.pad(
+                 kv,
+                 (
+                     (
+                         0,
+                         cdiv(kv_len, page_size) * page_size - kv_len,
+                     ),
+                     (0, 0),
+                     (0, 0),
+                     (0, 0),
+                 ),
+                 constant_values=jnp.nan,
+             ).reshape(
+                 -1,
+                 page_size,
+                 num_kv_heads_x2 // kv_packing,
+                 kv_packing,
+                 padded_head_dim,
+             )
+             indices = page_cnt + jnp.arange(kv.shape[0], dtype=jnp.int32)
+             indices = jnp.pad(
+                 indices,
+                 ((0, pages_per_seq - indices.shape[0]), ),
+                 constant_values=jnp.nan,
+             )
+             page_indices_list.append(indices)
+             page_cnt += kv.shape[0]
+             kv_pages_list.append(kv)
+
+         kv_cache = jnp.concatenate(kv_pages_list, axis=0)
+         kv_cache = jnp.pad(
+             kv_cache,
+             ((0, num_pages - kv_cache.shape[0]), (0, 0), (0, 0), (0, 0),
+              (0, 0)),
+             constant_values=jnp.nan,
+         )
+         page_indices = jnp.stack(page_indices_list, axis=0)
+         page_indices = jnp.pad(
+             page_indices,
+             ((0, max_num_seq - page_indices.shape[0]), (0, 0)),
+             constant_values=jnp.nan,
+         )
+         page_indices = page_indices.reshape(-1)
+
+         cu_q_lens = jnp.array(cu_q_lens, dtype=jnp.int32)
+         cu_q_lens = jnp.pad(cu_q_lens,
+                             (0, max_num_seq + 1 - cu_q_lens.shape[0]))
+         kv_lens = jnp.array(kv_lens, dtype=jnp.int32)
+         kv_lens = jnp.pad(kv_lens, (0, max_num_seq - kv_lens.shape[0]))
+         distribution = jnp.array([0, 0, len(seq_lens)], dtype=jnp.int32)
+
+         args = (
+             q,
+             k,
+             v,
+             kv_cache,
+             kv_lens,
+             page_indices,
+             cu_q_lens,
+             distribution,
+         )
+
+         kwargs = {
+             "sliding_window": sliding_window,
+             "soft_cap": soft_cap,
+             "q_scale": q_scale,
+             "k_scale": k_scale,
+             "v_scale": v_scale,
+         }
+
+         expected, expected_kv_cache = ref_ragged_paged_attention(
+             *args,
+             **kwargs,
+         )
+
+         output, updated_kv_cache = ragged_paged_attention(
+             *args,
+             **kwargs,
+             num_kv_pages_per_block=num_kv_pages_per_block,
+             num_queries_per_block=num_queries_per_block,
+             vmem_limit_bytes=vmem_limit_bytes,
+         )
+         output = output[:cu_q_lens[distribution[-1]]]
+
+         dtype_bits = (dtypes.bit_width(jnp.dtype(kv_dtype)) if hasattr(
+             dtypes, "bit_width") else dtypes.itemsize_bits(
+                 jnp.dtype(kv_dtype)))
+         tols = {
+             32: 0.15,
+             16: 0.2,
+             8: 0.2,
+             4: 0.2,
+         }
+         tol = tols[dtype_bits]
+         self.assertAllClose(output, expected, atol=tol, rtol=tol)
+         mask = ~jnp.isnan(expected_kv_cache)
+         self.assertArraysEqual(updated_kv_cache[mask], expected_kv_cache[mask])
+         self.assertEqual(output.shape[-1], head_dim)
+
+     @parameterized.product(dtype=[jnp.float32, jnp.bfloat16], )
+     def test_ragged_paged_attention_basic(self, dtype):
+         seq_lens = [(192, 328), (128, 180), (64, 255)]
+         num_heads = (32, 8)
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             dtype,
+             dtype,
+             num_pages,
+         )
+
+     # TODO: support integer (int8, int4) and fp4 kv cache
+     @parameterized.product(
+         q_dtype=[jnp.bfloat16],
+         kv_dtype=[jnp.float8_e5m2, jnp.float8_e4m3fn],
+         kv_scales=[(0.5, 0.5), (1.0, 1.0)],
+     )
+     def test_ragged_paged_attention_quantized_kv_cache(self, q_dtype, kv_dtype,
+                                                        kv_scales):
+         if not jtu.is_device_tpu_at_least(version=5):
+             self.skipTest("Expect TPUv5+")
+         seq_lens = [(192, 328), (128, 180), (64, 255)]
+         num_heads = (32, 8)
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+         k_scale, v_scale = kv_scales
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             q_dtype,
+             kv_dtype,
+             num_pages,
+             k_scale=k_scale,
+             v_scale=v_scale,
+         )
+
+     @parameterized.product(
+         q_dtype=[jnp.bfloat16],
+         kv_dtype=[jnp.float8_e5m2, jnp.float8_e4m3fn],
+         q_scale=[0.5, 1.0],
+         kv_scales=[(0.5, 0.5), (1.0, 1.0)],
+     )
+     def test_ragged_paged_attention_quantized_attention(
+             self, q_dtype, kv_dtype, q_scale, kv_scales):
+         if not jtu.is_device_tpu_at_least(version=5):
+             self.skipTest("Expect TPUv5+")
+         seq_lens = [(192, 328), (128, 180), (64, 255)]
+         num_heads = (32, 8)
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+         k_scale, v_scale = kv_scales
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             q_dtype,
+             kv_dtype,
+             num_pages,
+             q_scale=q_scale,
+             k_scale=k_scale,
+             v_scale=v_scale,
+         )
+
+     @parameterized.product(dtype=[jnp.float32, jnp.bfloat16], )
+     def test_ragged_paged_attention_decode_only(self, dtype):
+         seq_lens = [
+             (1, 18),
+             (1, 129),
+             (1, 597),
+             (1, 122),
+             (1, 64),
+             (1, 322),
+             (1, 463),
+             (1, 181),
+             (1, 1107),
+             (1, 123),
+             (1, 31),
+             (1, 18),
+             (1, 1229),
+             (1, 229),
+             (1, 87),
+             (1, 1328),
+         ]
+         num_heads = (32, 8)
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             dtype,
+             dtype,
+             num_pages,
+         )
+
+     @parameterized.product(dtype=[jnp.float32, jnp.bfloat16], )
+     def test_ragged_paged_attention_prefill_only(self, dtype):
+         seq_lens = [
+             (5, 18),
+             (15, 129),
+             (120, 597),
+             (100, 122),
+             (21, 64),
+             (32, 322),
+             (251, 463),
+             (40, 181),
+             (64, 1107),
+             (99, 123),
+             (10, 31),
+             (5, 18),
+             (3, 1229),
+             (120, 229),
+             (9, 87),
+             (2, 1328),
+         ]
+         num_heads = (32, 8)
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             dtype,
+             dtype,
+             num_pages,
+         )
+
+     @parameterized.product(dtype=[jnp.float32, jnp.bfloat16], )
+     def test_ragged_paged_attention_mixed(self, dtype):
+         seq_lens = [
+             (5, 18),
+             (1, 129),
+             (120, 597),
+             (1, 122),
+             (1, 64),
+             (32, 322),
+             (251, 463),
+             (1, 181),
+             (1, 1107),
+             (99, 123),
+             (1, 31),
+             (5, 18),
+             (3, 1229),
+             (117, 229),
+             (1, 87),
+             (1, 1328),
+         ]
+         num_heads = (32, 8)
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             dtype,
+             dtype,
+             num_pages,
+         )
+
+     @parameterized.product(
+         num_seqs=[1, 17],
+         num_heads=[(32, 8), (12, 2), (5, 1), (3, 3)],
+         head_dim=[80, 240],
+         dtype=[jnp.float32, jnp.bfloat16],
+         # num_kv_pages_per_block=[8, 16],
+         # num_queries_per_block=[16, 32],
+     )
+     def test_ragged_paged_attention_complex(
+         self,
+         num_seqs,
+         num_heads,
+         head_dim,
+         dtype,
+         # num_kv_pages_per_block,
+         # num_queries_per_block,
+     ):
+         rng = np.random.default_rng(1234)
+         q_lens = rng.integers(1, 100, num_seqs)
+         kv_lens = q_lens + rng.integers(0, 50, num_seqs)
+         seq_lens = list(zip(q_lens.tolist(), kv_lens.tolist()))
+         page_size = 16
+         num_pages = 1000
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             dtype,
+             dtype,
+             num_pages,
+             # num_kv_pages_per_block=num_kv_pages_per_block,
+             # num_queries_per_block=num_queries_per_block,
+         )
+
+     @parameterized.product(sliding_window=[None, 5, 128], )
+     def test_ragged_paged_attention_sliding_window(
+         self,
+         sliding_window: int | None,
+     ):
+         num_seqs = 5
+         num_heads = (4, 4)
+         dtype = jnp.float32
+         rng = np.random.default_rng(1234)
+         q_lens = rng.integers(1, 100, num_seqs)
+         kv_lens = q_lens + rng.integers(0, 50, num_seqs)
+         seq_lens = list(zip(q_lens.tolist(), kv_lens.tolist()))
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             dtype,
+             dtype,
+             num_pages,
+             sliding_window=sliding_window,
+         )
+
+     @parameterized.product(soft_cap=[None, 50.0], )
+     def test_ragged_paged_attention_logit_soft_capping(
+         self,
+         soft_cap: float | None,
+     ):
+         num_heads = (16, 2)
+         num_seqs = 2
+         dtype = jnp.float32
+         rng = np.random.default_rng(1234)
+         q_lens = rng.integers(1, 100, num_seqs)
+         kv_lens = q_lens + rng.integers(0, 50, num_seqs)
+         seq_lens = list(zip(q_lens.tolist(), kv_lens.tolist()))
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             dtype,
+             dtype,
+             num_pages,
+             soft_cap=soft_cap,
+         )
+
+     def test_ragged_paged_attention_sliding_window_should_be_positive(self):
+         dtype = jnp.float32
+         seq_lens = [(192, 328), (128, 180), (64, 255)]
+         num_heads = (32, 8)
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+
+         with self.assertRaisesRegex(ValueError, "must be positive"):
+             self._test_ragged_paged_attention(
+                 seq_lens,
+                 num_heads,
+                 head_dim,
+                 page_size,
+                 dtype,
+                 dtype,
+                 num_pages,
+                 sliding_window=0,
+             )
+
+         with self.assertRaisesRegex(ValueError, "must be positive"):
+             self._test_ragged_paged_attention(
+                 seq_lens,
+                 num_heads,
+                 head_dim,
+                 page_size,
+                 dtype,
+                 dtype,
+                 num_pages,
+                 sliding_window=-1,
+             )
+
+     def test_ragged_paged_attention_soft_cap_cannot_be_zero(self):
+         dtype = jnp.float32
+         seq_lens = [(192, 328), (128, 180), (64, 255)]
+         num_heads = (32, 8)
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+
+         with self.assertRaisesRegex(ValueError, "must not be 0.0"):
+             self._test_ragged_paged_attention(
+                 seq_lens,
+                 num_heads,
+                 head_dim,
+                 page_size,
+                 dtype,
+                 dtype,
+                 num_pages,
+                 soft_cap=0.0,
+             )
+
+
+ if __name__ == "__main__":
+     absltest.main(testLoader=jtu.JaxTestLoader())
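Note (reviewer aid, not part of the package): the test above feeds `ragged_paged_attention` a flattened ragged batch described by `cu_q_lens`, `kv_lens`, `page_indices`, and `distribution`. Below is a minimal, CPU-runnable sketch of how those arrays are laid out for the basic test's sequence lengths. `align_to` and `cdiv` are re-implemented inline as stand-ins for the package helpers, and reading `distribution` as cumulative sequence-category counts is an assumption, not something this diff states.

    import jax.numpy as jnp

    def align_to(x, a):  # stand-in for the package's align_to: round up to a multiple of a
        return (x + a - 1) // a * a

    def cdiv(a, b):  # stand-in for the package's cdiv: ceiling division
        return (a + b - 1) // b

    seq_lens = [(192, 328), (128, 180), (64, 255)]  # (q_len, kv_len), as in the basic test
    page_size = 16

    cu_q_lens = [0]
    kv_lens = []
    for q_len, kv_len in seq_lens:
        # Query rows of sequence i occupy [cu_q_lens[i], cu_q_lens[i+1]) in the
        # flattened [total_tokens, num_heads, head_dim] q tensor.
        cu_q_lens.append(cu_q_lens[-1] + q_len)
        kv_lens.append(kv_len)

    max_num_seq = align_to(len(seq_lens), 8)       # batch padded to a multiple of 8
    pages_per_seq = cdiv(max(kv_lens), page_size)  # width of each row of the page table

    cu_q_lens = jnp.pad(jnp.array(cu_q_lens, jnp.int32),
                        (0, max_num_seq + 1 - len(cu_q_lens)))
    kv_lens = jnp.pad(jnp.array(kv_lens, jnp.int32),
                      (0, max_num_seq - len(seq_lens)))
    distribution = jnp.array([0, 0, len(seq_lens)], dtype=jnp.int32)

    print(cu_q_lens)      # [  0 192 320 384   0   0   0   0   0]
    print(kv_lens)        # [328 180 255   0   0   0   0   0]
    print(pages_per_seq)  # 21

The slice `output[:cu_q_lens[distribution[-1]]]` in the test then keeps only the rows belonging to real tokens, discarding the padded tail.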
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -0,0 +1,156 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from unittest.mock import MagicMock
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import pytest
+ from jax.sharding import Mesh
+
+ from tpu_inference.layers.common.attention_interface import attention
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.runner.kv_cache import get_kv_cache_shape_with_mesh
+
+ # ---- Test Configuration & Constants ----
+
+ # Total number of tokens across all sequences in the batch
+ TOTAL_TOKENS = 10
+ # Number of sequences in the batch
+ NUM_SEQS = 2
+ # Padded maximum number of sequences
+ MAX_NUM_SEQS = 4
+ # Number of attention heads (Query)
+ NUM_HEADS = 8
+ # Number of attention heads (Key/Value) - for Grouped-Query Attention
+ NUM_KV_HEADS = 4
+ # Total number of blocks in the KV cache
+ NUM_BLOCKS = 32
+ # Number of tokens per block
+ BLOCK_SIZE = 16
+ # Maximum number of blocks a single sequence can occupy
+ MAX_BLOCKS_PER_SEQ = 8
+
+
+ @pytest.fixture
+ def mesh():
+     """Provides a mock 1D JAX mesh for testing."""
+     # Create a mesh with available devices, useful for running on CPU/GPU/TPU
+     # For this test, it will likely be a single CPU device.
+     devices = np.array(jax.local_devices()[:1])
+     if not devices.any():
+         # Add a mock device if no devices are present (e.g., in a CI environment)
+         devices = np.array([jax.devices("cpu")[0]])
+     return Mesh(devices.reshape((-1, 1, 1)), ("data", "attn_dp", "model"))
+
+
+ # ---- Test for `attention` ----
+
+
+ def _test_attention(monkeypatch, mesh, head_dim, use_sinks=False):
+     """
+     Tests the main `attention` function.
+
+     Verifies that:
+     1. It calls the `sharded_ragged_paged_attention` kernel with correct metadata.
+     2. The final outputs (kv_cache and attention output) have the correct shapes.
+     """
+     # 1. Arrange
+
+     # Create input tensors
+     q_dtype = jnp.float32
+     kv_dtype = jnp.float32
+     q = jnp.ones((TOTAL_TOKENS, NUM_HEADS, head_dim), dtype=q_dtype)
+     k = jnp.ones((TOTAL_TOKENS, NUM_KV_HEADS, head_dim), dtype=kv_dtype)
+     v = jnp.ones((TOTAL_TOKENS, NUM_KV_HEADS, head_dim), dtype=kv_dtype)
+     sinks = jnp.ones((NUM_HEADS, ), dtype=jnp.float32) if use_sinks else None
+
+     kv_cache_shape = get_kv_cache_shape_with_mesh(
+         mesh,
+         NUM_BLOCKS,
+         BLOCK_SIZE,
+         NUM_KV_HEADS,
+         head_dim,
+         kv_dtype,
+     )
+     kv_cache = jnp.zeros(kv_cache_shape, dtype=kv_dtype)
+
+     # Mock ragged_paged_attention to return a tensor of the correct shape
+     mock_paged_attn_kernel = MagicMock(return_value=(jnp.ones(
+         (TOTAL_TOKENS, NUM_HEADS, head_dim)), kv_cache), )
+
+     if head_dim == 64:
+         monkeypatch.setattr(
+             "tpu_inference.layers.common.attention_interface.ragged_paged_attention_hd64",
+             mock_paged_attn_kernel,
+         )
+     else:
+         monkeypatch.setattr(
+             "tpu_inference.layers.common.attention_interface.ragged_paged_attention",
+             mock_paged_attn_kernel,
+         )
+
+     # Create AttentionMetadata
+     attention_metadata = AttentionMetadata(
+         input_positions=jnp.arange(TOTAL_TOKENS, dtype=jnp.int32),
+         block_tables=jnp.zeros((MAX_NUM_SEQS * MAX_BLOCKS_PER_SEQ, ),
+                                dtype=jnp.int32),
+         seq_lens=jnp.array([5, 5, 0, 0], dtype=jnp.int32),
+         query_start_loc=jnp.array([0, 5, 10, 10, 10], dtype=jnp.int32),
+         request_distribution=jnp.array([0, 0, NUM_SEQS], dtype=jnp.int32),
+     )
+
+     # 2. Act
+     final_kv_cache, output = attention(
+         kv_cache=kv_cache,
+         q=q,
+         k=k,
+         v=v,
+         attention_metadata=attention_metadata,
+         mesh=mesh,
+         head_dim_original=head_dim,
+         sinks=sinks,
+     )
+
+     # 3. Assert
+ # 3. Assert
128
+ # Check that both mocked kernels were called
129
+ mock_paged_attn_kernel.assert_called_once()
130
+
131
+ # Check output shapes
132
+ assert final_kv_cache.shape == kv_cache.shape
133
+ assert output.shape == q.shape
134
+
135
+ # Check that the output is the one from our mock
136
+ assert jnp.all(output == 1.0)
137
+
138
+
139
+ def test_attention(monkeypatch, mesh):
140
+ _test_attention(monkeypatch, mesh, 128)
141
+
142
+
143
+ def test_attention_hd64(monkeypatch, mesh):
144
+ _test_attention(monkeypatch, mesh, 64)
145
+
146
+
147
+ def test_attention_sink(monkeypatch, mesh):
148
+ _test_attention(monkeypatch, mesh, 64, True)
149
+
150
+
151
+ def test_attention_sink_no_64_raises_error(monkeypatch, mesh):
152
+ with pytest.raises(
153
+ NotImplementedError,
154
+ match="Attention sink support is only available when head_dim==64"
155
+ ):
156
+ _test_attention(monkeypatch, mesh, 128, True)
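End-of-diff note (reviewer aid, not part of the package): `_test_attention` above encodes two real 5-token sequences inside buffers padded to MAX_NUM_SEQS=4, with `query_start_loc` acting as a padded cumulative sum of per-sequence query counts. A standalone sketch of that invariant, under the assumption that padded slots carry zero queries and zero sequence length:

    import jax.numpy as jnp

    NUM_SEQS = 2       # real sequences, as in the test above
    TOTAL_TOKENS = 10  # 5 + 5 query tokens

    seq_lens = jnp.array([5, 5, 0, 0], dtype=jnp.int32)
    query_start_loc = jnp.array([0, 5, 10, 10, 10], dtype=jnp.int32)

    # Per-sequence query counts recovered from the cumulative offsets.
    q_counts = query_start_loc[1:] - query_start_loc[:-1]

    assert int(q_counts.sum()) == TOTAL_TOKENS
    assert bool((q_counts[NUM_SEQS:] == 0).all())  # padded slots add no queries
    assert bool((seq_lens[NUM_SEQS:] == 0).all())  # ...and no kv history
    print(q_counts)  # [5 5 0 0]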