tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import jax
 import jax.numpy as jnp
 import numpy as np
@@ -42,6 +56,7 @@ class MlaRaggedPagedAttentionKernelTest(jtu.JaxTestCase):
 
         padded_r_dim = align_to(r_dim, 128)
         padded_lkv_dim = align_to(lkv_dim, 128)
+        padded_kv_dim = padded_lkv_dim + padded_r_dim
         packing = get_dtype_packing(kv_dtype)
         q_lens = [s[0] for s in seq_lens]
         kv_lens_list = [s[1] for s in seq_lens]
@@ -69,13 +84,10 @@ class MlaRaggedPagedAttentionKernelTest(jtu.JaxTestCase):
         new_kv_c = gen_random((total_q_len, lkv_dim), kv_dtype)
         new_k_pe = gen_random((total_q_len, r_dim), kv_dtype)
 
-        cache_kv_c = gen_random(
-            (total_num_pages, page_size // packing, packing, padded_lkv_dim),
+        cache_kv = gen_random(
+            (total_num_pages, page_size // packing, packing, padded_kv_dim),
             kv_dtype,
         )
-        cache_k_pe = gen_random(
-            (total_num_pages, page_size // packing, packing, padded_r_dim),
-            kv_dtype)
         kv_lens = jnp.array(kv_lens_list, dtype=jnp.int32)
         page_indices = jnp.array(page_indices_list, dtype=jnp.int32)
         cu_q_lens = jnp.array(cu_q_lens_list, dtype=jnp.int32)
@@ -84,14 +96,13 @@ class MlaRaggedPagedAttentionKernelTest(jtu.JaxTestCase):
         ql_nope_for_kernel = ql_nope.copy()
         q_pe_for_kernel = q_pe.copy()
 
-        expected_out, expected_updated_kv_c, expeceted_updated_k_pe = (
+        expected_out, expected_updated_kv = (
            mla.ref_mla_ragged_paged_attention(
                 ql_nope,
                 q_pe,
                 new_kv_c,
                 new_k_pe,
-                cache_kv_c.copy(),
-                cache_k_pe.copy(),
+                cache_kv.copy(),
                 kv_lens,
                 page_indices,
                 cu_q_lens,
@@ -101,50 +112,141 @@ class MlaRaggedPagedAttentionKernelTest(jtu.JaxTestCase):
                 soft_cap=soft_cap,
             ))
 
-        kernel_out, kernel_updated_kv_c, kernel_updated_k_pe = (
-            mla.mla_ragged_paged_attention(
-                ql_nope_for_kernel,
-                q_pe_for_kernel,
-                new_kv_c,
-                new_k_pe,
-                cache_kv_c.copy(),
-                cache_k_pe.copy(),
-                kv_lens,
-                page_indices,
-                cu_q_lens,
-                distribution,
-                sm_scale=sm_scale,
-                sliding_window=sliding_window,
-                soft_cap=soft_cap,
-                num_kv_pages_per_block=num_kv_pages_per_block,
-                num_queries_per_block=num_queries_per_block,
-                vmem_limit_bytes=vmem_limit_bytes,
-            ))
+        kernel_out, kernel_updated_kv = (mla.mla_ragged_paged_attention(
+            ql_nope_for_kernel,
+            q_pe_for_kernel,
+            new_kv_c,
+            new_k_pe,
+            cache_kv.copy(),
+            kv_lens,
+            page_indices,
+            cu_q_lens,
+            distribution,
+            sm_scale=sm_scale,
+            sliding_window=sliding_window,
+            soft_cap=soft_cap,
+            num_kv_pages_per_block=num_kv_pages_per_block,
+            num_queries_per_block=num_queries_per_block,
+            vmem_limit_bytes=vmem_limit_bytes,
+        ))
 
         self.assertEqual(expected_out.shape,
                          (total_q_len, num_heads, padded_lkv_dim))
         self.assertEqual(
-            expected_updated_kv_c.shape,
-            (total_num_pages, page_size // packing, packing, padded_lkv_dim),
-        )
-        self.assertEqual(
-            expeceted_updated_k_pe.shape,
-            (total_num_pages, page_size // packing, packing, padded_r_dim),
+            expected_updated_kv.shape,
+            (total_num_pages, page_size // packing, packing, padded_kv_dim),
         )
         self.assertEqual(expected_out.dtype, kv_dtype)
-        self.assertEqual(expected_updated_kv_c.dtype, kv_dtype)
-        self.assertEqual(expeceted_updated_k_pe.dtype, kv_dtype)
+        self.assertEqual(expected_updated_kv.dtype, kv_dtype)
 
         self.assertAllClose(expected_out, kernel_out, atol=0.2, rtol=0.2)
-        self.assertAllClose(expected_updated_kv_c,
-                            kernel_updated_kv_c,
-                            atol=0.2,
-                            rtol=0.2)
-        self.assertAllClose(expeceted_updated_k_pe,
-                            kernel_updated_k_pe,
+        self.assertAllClose(expected_updated_kv,
+                            kernel_updated_kv,
                             atol=0.2,
                             rtol=0.2)
 
+    def test_update_kv_cache(self):
+        lkv_dim = 4
+        r_dim = 4
+        padded_lkv_dim = align_to(lkv_dim, 128)
+        padded_r_dim = align_to(r_dim, 128)
+        kv_dtype = jnp.bfloat16
+        new_kv_c = jnp.arange(16, dtype=kv_dtype).reshape((4, lkv_dim))
+        new_k_pe = (jnp.arange(16, dtype=kv_dtype).reshape((4, r_dim)) + 100)
+        total_num_pages = 2
+        page_size = 4
+        cache_kv_shape = mla.get_kv_cache_shape(
+            total_num_pages,
+            page_size,
+            padded_lkv_dim + padded_r_dim,
+            kv_dtype,
+        )
+        cache_kv = jnp.zeros(cache_kv_shape, dtype=kv_dtype)
+
+        # two sequences, first with 3 tokens, second with 1 token
+        kv_lens = jnp.array([3, 1], dtype=jnp.int32)
+        # first seq uses page 0, second uses page 1
+        page_indices = jnp.array([0, -1, 1, -1], dtype=jnp.int32)
+        # three tokens for first seq, one for second
+        cu_q_lens = jnp.array([0, 3, 4], dtype=jnp.int32)
+        distribution = jnp.array([0, 0, 2], dtype=jnp.int32)
+
+        # manually compute the expected cache
+        padded_new_kv_c = jnp.pad(new_kv_c,
+                                  ((0, 0), (0, padded_lkv_dim - lkv_dim)),
+                                  constant_values=0)
+        padded_new_k_pe = jnp.pad(new_k_pe,
+                                  ((0, 0), (0, padded_r_dim - r_dim)),
+                                  constant_values=0)
+
+        expected_cache = cache_kv
+        # First sequence
+        # token 0
+        page_idx, row, col = 0, 0, 0
+        expected_cache = expected_cache.at[page_idx, row,
+                                           col, :padded_lkv_dim].set(
+                                               padded_new_kv_c[0])
+        expected_cache = expected_cache.at[page_idx, row, col,
+                                           padded_lkv_dim:padded_lkv_dim +
+                                           padded_r_dim].set(
+                                               padded_new_k_pe[0])
+        # token 1
+        page_idx, row, col = 0, 0, 1
+        expected_cache = expected_cache.at[page_idx, row,
+                                           col, :padded_lkv_dim].set(
+                                               padded_new_kv_c[1])
+        expected_cache = expected_cache.at[page_idx, row, col,
+                                           padded_lkv_dim:padded_lkv_dim +
+                                           padded_r_dim].set(
+                                               padded_new_k_pe[1])
+        # token 2
+        page_idx, row, col = 0, 1, 0
+        expected_cache = expected_cache.at[page_idx, row,
+                                           col, :padded_lkv_dim].set(
+                                               padded_new_kv_c[2])
+        expected_cache = expected_cache.at[page_idx, row, col,
+                                           padded_lkv_dim:padded_lkv_dim +
+                                           padded_r_dim].set(
+                                               padded_new_k_pe[2])
+
+        # Second sequence
+        # token 0
+        page_idx, row, col = 1, 0, 0
+        expected_cache = expected_cache.at[page_idx, row,
+                                           col, :padded_lkv_dim].set(
+                                               padded_new_kv_c[3])
+        expected_cache = expected_cache.at[page_idx, row, col,
+                                           padded_lkv_dim:padded_lkv_dim +
+                                           padded_r_dim].set(
+                                               padded_new_k_pe[3])
+
+        updated_cache = mla.update_kv_cache(
+            new_kv_c,
+            new_k_pe,
+            cache_kv,
+            kv_lens,
+            page_indices,
+            cu_q_lens,
+            distribution,
+        )
+
+        self.assertAllClose(updated_cache, expected_cache)
+
+    def test_get_kv_cache_shape(self):
+        total_num_pages = 10
+        page_size = 16
+        lkv_dim = 128
+        kv_dtype = jnp.bfloat16
+        # The calculation for the expected shape is as follows:
+        # kv_packing is determined by the dtype, which is 2 for bfloat16.
+        # The second dimension is page_size / kv_packing = 16 / 2 = 8
+        # The third dimension is kv_packing = 2
+        # The fourth dimension is lkv_dim aligned to 128, which is 128
+        expected_shape = (10, 8, 2, 128)
+        self.assertEqual(
+            mla.get_kv_cache_shape(total_num_pages, page_size, lkv_dim,
+                                   kv_dtype), expected_shape)
+
     def test_ragged_paged_attention_basic(self):
         dtype = jnp.bfloat16
         seq_lens = [(192, 328), (128, 180), (64, 255)]
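The hunks above fold the separate cache_kv_c and cache_k_pe buffers into a single cache_kv whose last dimension carries both the latent-KV and RoPE parts (padded_lkv_dim + padded_r_dim), and mla.mla_ragged_paged_attention / mla.update_kv_cache now take that one cache argument. A minimal sketch of the cache-shape arithmetic that test_get_kv_cache_shape pins down, with local stand-ins for the test's align_to and get_dtype_packing helpers and assuming "packing" means elements per 32-bit word (the test only confirms 2 for bfloat16):

    import jax.numpy as jnp

    def align(x: int, n: int) -> int:
        # Local stand-in for align_to: round x up to a multiple of n.
        return (x + n - 1) // n * n

    def dtype_packing(dtype) -> int:
        # Local stand-in for get_dtype_packing: elements per 32-bit word
        # (2 for bfloat16, 1 for float32, 4 for 8-bit types).
        return 32 // (jnp.dtype(dtype).itemsize * 8)

    def kv_cache_shape(total_num_pages: int, page_size: int, kv_dim: int, dtype):
        # Mirrors the expectation asserted in test_get_kv_cache_shape.
        packing = dtype_packing(dtype)
        return (total_num_pages, page_size // packing, packing, align(kv_dim, 128))

    # Same numbers as the test: bfloat16 packs 2 elements, so (10, 16 // 2, 2, 128).
    assert kv_cache_shape(10, 16, 128, jnp.bfloat16) == (10, 8, 2, 128)

In the merged layout, a token's latent KV occupies cache[page, row, col, :padded_lkv_dim] and its RoPE part fills the remaining padded_r_dim slots, which is exactly what test_update_kv_cache spells out element by element.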
@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import functools
-
 import jax
 import jax.numpy as jnp
 from absl.testing import absltest, parameterized
@@ -10,6 +8,7 @@ from jax._src import test_util as jtu
 from tpu_inference.kernels.quantized_matmul import (kernel, tuned_block_sizes,
                                                     util)
 
+xla_quantized_matmul = kernel.xla_quantized_matmul
 quantized_matmul_kernel = kernel.quantized_matmul_kernel
 quantize_tensor = util.quantize_tensor
 get_tuned_block_sizes = tuned_block_sizes.get_tuned_block_sizes
@@ -17,37 +16,6 @@ get_tuned_block_sizes = tuned_block_sizes.get_tuned_block_sizes
 jax.config.parse_flags_with_absl()
 
 
-@functools.partial(jax.jit, static_argnames=["quantize_activation"])
-def reference_quantized_matmul(
-    x: jax.Array,
-    w_q: jax.Array,
-    w_scale: jax.Array,
-    quantize_activation=True,
-):
-    if quantize_activation:
-        acc_dtype = jnp.float32
-        if quantize_activation and jnp.issubdtype(w_q.dtype, jnp.integer):
-            acc_dtype = jnp.int32
-
-        x_q, x_scale = quantize_tensor(x, w_q.dtype)
-        out = jax.lax.dot_general(
-            x_q,
-            w_q,
-            dimension_numbers=(((1, ), (1, )), ((), ())),
-            preferred_element_type=acc_dtype,
-        ).astype(jnp.float32)
-        out *= x_scale
-    else:
-        out = jax.lax.dot_general(
-            x,
-            w_q,
-            dimension_numbers=(((1, ), (1, )), ((), ())),
-            preferred_element_type=jnp.float32,
-        )
-    out *= jnp.expand_dims(w_scale, 0)
-    return out.astype(x.dtype)
-
-
 @jtu.with_config(jax_numpy_dtype_promotion="standard")
 class QuantizedMatmulKernelTest(jtu.JaxTestCase):
 
@@ -94,7 +62,7 @@ class QuantizedMatmulKernelTest(jtu.JaxTestCase):
             x_q_dtype=x_q_dtype,
             tuned_value=tuned_value,
         )
-        expected = reference_quantized_matmul(
+        expected = xla_quantized_matmul(
             x, w_q, w_scale, quantize_activation=quantize_activation)
 
         self.assertAllClose(output,
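With this change the test compares the Pallas kernel against kernel.xla_quantized_matmul instead of a local reference_quantized_matmul copy; both compute the same thing: quantize the activation to the weight dtype, accumulate the matmul in int32 (integer weights) or float32, then rescale by the activation and weight scales. A rough self-contained sketch of that reference computation, with a hypothetical symmetric per-row quantizer standing in for util.quantize_tensor (whose exact behavior is not shown in this diff):

    import jax
    import jax.numpy as jnp

    def quantize_rows(x: jax.Array, q_dtype=jnp.int8):
        # Hypothetical stand-in for util.quantize_tensor: symmetric per-row scaling.
        max_abs = jnp.maximum(jnp.max(jnp.abs(x), axis=-1, keepdims=True), 1e-6)
        scale = max_abs / jnp.iinfo(q_dtype).max
        return jnp.round(x / scale).astype(q_dtype), scale

    def dequantized_matmul(x: jax.Array, w_q: jax.Array, w_scale: jax.Array):
        # x: (m, k) activations; w_q: (n, k) int8 weights; w_scale: (n,) per-channel scales.
        x_q, x_scale = quantize_rows(x, w_q.dtype)
        acc = jax.lax.dot_general(
            x_q, w_q,
            dimension_numbers=(((1, ), (1, )), ((), ())),
            # Integer accumulation for integer weights, as in the deleted helper.
            preferred_element_type=jnp.int32,
        ).astype(jnp.float32)
        return (acc * x_scale * w_scale[None, :]).astype(x.dtype)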
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import jax
 import jax.numpy as jnp
 import numpy as np
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import random
 
 import jax
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import jax
 import jax.numpy as jnp
 import numpy as np
@@ -176,7 +190,9 @@ class RaggedPagedAttentionHeadDim64KernelTest(jtu.JaxTestCase):
         )
         output = output[:cu_q_lens[distribution[-1]]]
 
-        dtype_bits = dtypes.bit_width(jnp.dtype(kv_dtype))
+        dtype_bits = (dtypes.bit_width(jnp.dtype(kv_dtype)) if hasattr(
+            dtypes, "bit_width") else dtypes.itemsize_bits(
+                jnp.dtype(kv_dtype)))
         tols = {
             32: 0.15,
             16: 0.2,
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import jax
 import jax.numpy as jnp
 import numpy as np
@@ -162,7 +176,9 @@ class RaggedPagedAttentionKernelTest(jtu.JaxTestCase):
         )
         output = output[:cu_q_lens[distribution[-1]]]
 
-        dtype_bits = dtypes.bit_width(jnp.dtype(kv_dtype))
+        dtype_bits = (dtypes.bit_width(jnp.dtype(kv_dtype)) if hasattr(
+            dtypes, "bit_width") else dtypes.itemsize_bits(
+                jnp.dtype(kv_dtype)))
         tols = {
             32: 0.15,
             16: 0.2,
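The dtype_bits change in the last two hunks is a compatibility shim: the tests reach into a private JAX module for the element bit width and now accept either spelling, bit_width or itemsize_bits, depending on which one the installed JAX version exposes. A small sketch of the same guard, assuming dtypes here refers to jax._src.dtypes as it does elsewhere in these test files:

    import jax.numpy as jnp
    from jax._src import dtypes  # private API; the helper name varies across JAX versions

    def bit_width_of(dtype) -> int:
        # Prefer dtypes.bit_width, fall back to dtypes.itemsize_bits, as in the hunks above.
        if hasattr(dtypes, "bit_width"):
            return dtypes.bit_width(jnp.dtype(dtype))
        return dtypes.itemsize_bits(jnp.dtype(dtype))

The result is only used to pick a comparison tolerance per KV-cache dtype (e.g. 16 bits for bfloat16 selects the 0.2 tolerance).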
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -0,0 +1,156 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+from jax.sharding import Mesh
+
+from tpu_inference.layers.common.attention_interface import attention
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.runner.kv_cache import get_kv_cache_shape_with_mesh
+
+# ---- Test Configuration & Constants ----
+
+# Total number of tokens across all sequences in the batch
+TOTAL_TOKENS = 10
+# Number of sequences in the batch
+NUM_SEQS = 2
+# Padded maximum number of sequences
+MAX_NUM_SEQS = 4
+# Number of attention heads (Query)
+NUM_HEADS = 8
+# Number of attention heads (Key/Value) - for Grouped-Query Attention
+NUM_KV_HEADS = 4
+# Total number of blocks in the KV cache
+NUM_BLOCKS = 32
+# Number of tokens per block
+BLOCK_SIZE = 16
+# Maximum number of blocks a single sequence can occupy
+MAX_BLOCKS_PER_SEQ = 8
+
+
+@pytest.fixture
+def mesh():
+    """Provides a mock 1D JAX mesh for testing."""
+    # Create a mesh with available devices, useful for running on CPU/GPU/TPU
+    # For this test, it will likely be a single CPU device.
+    devices = np.array(jax.local_devices()[:1])
+    if not devices.any():
+        # Add a mock device if no devices are present (e.g., in a CI environment)
+        devices = np.array([jax.devices("cpu")[0]])
+    return Mesh(devices.reshape((-1, 1, 1)), ("data", "attn_dp", "model"))
+
+
+# ---- Test for `attention` ----
+
+
+def _test_attention(monkeypatch, mesh, head_dim, use_sinks=False):
+    """
+    Tests the main `attention` function.
+
+    Verifies that:
+    1. It calls the `sharded_ragged_paged_attention` kernel with correct metadata.
+    2. The final outputs (kv_cache and attention output) have the correct shapes.
+    """
+    # 1. Arrange
+
+    # Create input tensors
+    q_dtype = jnp.float32
+    kv_dtype = jnp.float32
+    q = jnp.ones((TOTAL_TOKENS, NUM_HEADS, head_dim), dtype=q_dtype)
+    k = jnp.ones((TOTAL_TOKENS, NUM_KV_HEADS, head_dim), dtype=kv_dtype)
+    v = jnp.ones((TOTAL_TOKENS, NUM_KV_HEADS, head_dim), dtype=kv_dtype)
+    sinks = jnp.ones((NUM_HEADS, ), dtype=jnp.float32) if use_sinks else None
+
+    kv_cache_shape = get_kv_cache_shape_with_mesh(
+        mesh,
+        NUM_BLOCKS,
+        BLOCK_SIZE,
+        NUM_KV_HEADS,
+        head_dim,
+        kv_dtype,
+    )
+    kv_cache = jnp.zeros(kv_cache_shape, dtype=kv_dtype)
+
+    # Mock ragged_paged_attention to return a tensor of the correct shape
+    mock_paged_attn_kernel = MagicMock(return_value=(jnp.ones(
+        (TOTAL_TOKENS, NUM_HEADS, head_dim)), kv_cache), )
+
+    if head_dim == 64:
+        monkeypatch.setattr(
+            "tpu_inference.layers.common.attention_interface.ragged_paged_attention_hd64",
+            mock_paged_attn_kernel,
+        )
+    else:
+        monkeypatch.setattr(
+            "tpu_inference.layers.common.attention_interface.ragged_paged_attention",
+            mock_paged_attn_kernel,
+        )
+
+    # Create AttentionMetadata
+    attention_metadata = AttentionMetadata(
+        input_positions=jnp.arange(TOTAL_TOKENS, dtype=jnp.int32),
+        block_tables=jnp.zeros((MAX_NUM_SEQS * MAX_BLOCKS_PER_SEQ, ),
+                               dtype=jnp.int32),
+        seq_lens=jnp.array([5, 5, 0, 0], dtype=jnp.int32),
+        query_start_loc=jnp.array([0, 5, 10, 10, 10], dtype=jnp.int32),
+        request_distribution=jnp.array([0, 0, NUM_SEQS], dtype=jnp.int32),
+    )
+
+    # 2. Act
+    final_kv_cache, output = attention(
+        kv_cache=kv_cache,
+        q=q,
+        k=k,
+        v=v,
+        attention_metadata=attention_metadata,
+        mesh=mesh,
+        head_dim_original=head_dim,
+        sinks=sinks,
+    )
+
+    # 3. Assert
+    # Check that both mocked kernels were called
+    mock_paged_attn_kernel.assert_called_once()
+
+    # Check output shapes
+    assert final_kv_cache.shape == kv_cache.shape
+    assert output.shape == q.shape
+
+    # Check that the output is the one from our mock
+    assert jnp.all(output == 1.0)
+
+
+def test_attention(monkeypatch, mesh):
+    _test_attention(monkeypatch, mesh, 128)
+
+
+def test_attention_hd64(monkeypatch, mesh):
+    _test_attention(monkeypatch, mesh, 64)
+
+
+def test_attention_sink(monkeypatch, mesh):
+    _test_attention(monkeypatch, mesh, 64, True)
+
+
+def test_attention_sink_no_64_raises_error(monkeypatch, mesh):
+    with pytest.raises(
+            NotImplementedError,
+            match="Attention sink support is only available when head_dim==64"
+    ):
+        _test_attention(monkeypatch, mesh, 128, True)
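This new test file exercises the dispatch behavior of layers.common.attention_interface.attention: head_dim == 64 routes to the hd64 ragged-paged-attention kernel, other head dims route to the standard kernel, and attention sinks are rejected unless head_dim == 64. A condensed sketch of that contract as the tests observe it, written as a standalone illustration rather than the package's actual implementation:

    def dispatch_attention_kernel(head_dim: int, sinks=None) -> str:
        # Contract inferred from the tests above, not the library's code.
        if sinks is not None and head_dim != 64:
            raise NotImplementedError(
                "Attention sink support is only available when head_dim==64")
        if head_dim == 64:
            return "ragged_paged_attention_hd64"
        return "ragged_paged_attention"

    assert dispatch_attention_kernel(64) == "ragged_paged_attention_hd64"
    assert dispatch_attention_kernel(128) == "ragged_paged_attention"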