tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (248)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +14 -0
  31. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  32. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
  35. tests/layers/__init__.py +13 -0
  36. tests/layers/common/__init__.py +13 -0
  37. tests/layers/common/test_attention_interface.py +156 -0
  38. tests/layers/common/test_quantization.py +149 -0
  39. tests/layers/jax/__init__.py +13 -0
  40. tests/layers/jax/attention/__init__.py +13 -0
  41. tests/layers/jax/attention/test_common_attention.py +103 -0
  42. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  43. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  44. tests/layers/jax/moe/__init__.py +13 -0
  45. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  46. tests/layers/jax/sample/__init__.py +13 -0
  47. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  48. tests/layers/jax/sample/test_sampling.py +115 -0
  49. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  50. tests/layers/jax/test_layers.py +155 -0
  51. tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
  52. tests/layers/jax/test_rope.py +93 -0
  53. tests/layers/jax/test_sharding.py +159 -0
  54. tests/layers/jax/test_transformer_block.py +152 -0
  55. tests/layers/vllm/__init__.py +13 -0
  56. tests/layers/vllm/test_attention.py +363 -0
  57. tests/layers/vllm/test_awq.py +406 -0
  58. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  59. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  61. tests/layers/vllm/test_fp8.py +17 -0
  62. tests/layers/vllm/test_mxfp4.py +320 -0
  63. tests/layers/vllm/test_unquantized.py +662 -0
  64. tests/layers/vllm/utils.py +87 -0
  65. tests/lora/__init__.py +13 -0
  66. tests/lora/conftest.py +14 -0
  67. tests/lora/test_bgmv.py +14 -0
  68. tests/lora/test_layers.py +25 -8
  69. tests/lora/test_lora.py +15 -1
  70. tests/lora/test_lora_perf.py +14 -0
  71. tests/models/__init__.py +13 -0
  72. tests/models/common/__init__.py +13 -0
  73. tests/models/common/test_model_loader.py +455 -0
  74. tests/models/jax/__init__.py +13 -0
  75. tests/models/jax/test_deepseek_v3.py +401 -0
  76. tests/models/jax/test_llama3.py +184 -0
  77. tests/models/jax/test_llama4.py +298 -0
  78. tests/models/jax/test_llama_eagle3.py +197 -0
  79. tests/models/jax/test_llama_guard_4.py +242 -0
  80. tests/models/jax/test_qwen2.py +172 -0
  81. tests/models/jax/test_qwen2_5_vl.py +605 -0
  82. tests/models/jax/test_qwen3.py +169 -0
  83. tests/models/jax/test_weight_loading.py +180 -0
  84. tests/models/jax/utils/__init__.py +13 -0
  85. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  86. tests/platforms/__init__.py +13 -0
  87. tests/platforms/test_tpu_platform.py +54 -0
  88. tests/runner/__init__.py +13 -0
  89. tests/runner/test_block_table.py +395 -0
  90. tests/runner/test_input_batch.py +226 -0
  91. tests/runner/test_kv_cache.py +220 -0
  92. tests/runner/test_kv_cache_manager.py +498 -0
  93. tests/runner/test_multimodal_manager.py +429 -0
  94. tests/runner/test_persistent_batch_manager.py +84 -0
  95. tests/runner/test_speculative_decoding_manager.py +368 -0
  96. tests/runner/test_structured_decoding_manager.py +220 -0
  97. tests/runner/test_tpu_runner.py +261 -0
  98. tests/runner/test_tpu_runner_dp.py +1099 -0
  99. tests/runner/test_tpu_runner_mesh.py +200 -0
  100. tests/runner/test_utils.py +411 -0
  101. tests/spec_decode/__init__.py +13 -0
  102. tests/spec_decode/test_eagle3.py +311 -0
  103. tests/test_base.py +14 -0
  104. tests/test_tpu_info.py +14 -0
  105. tests/test_utils.py +1 -43
  106. tests/worker/__init__.py +13 -0
  107. tests/worker/tpu_worker_test.py +414 -0
  108. tpu_inference/__init__.py +14 -0
  109. tpu_inference/core/__init__.py +13 -0
  110. tpu_inference/core/sched/__init__.py +13 -0
  111. tpu_inference/core/sched/dp_scheduler.py +372 -56
  112. tpu_inference/distributed/__init__.py +13 -0
  113. tpu_inference/distributed/jax_parallel_state.py +14 -0
  114. tpu_inference/distributed/tpu_connector.py +14 -9
  115. tpu_inference/distributed/utils.py +56 -4
  116. tpu_inference/executors/__init__.py +13 -0
  117. tpu_inference/executors/ray_distributed_executor.py +20 -3
  118. tpu_inference/experimental/__init__.py +13 -0
  119. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  120. tpu_inference/kernels/__init__.py +13 -0
  121. tpu_inference/kernels/collectives/__init__.py +13 -0
  122. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  123. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  124. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  125. tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
  126. tpu_inference/kernels/megablox/__init__.py +13 -0
  127. tpu_inference/kernels/megablox/common.py +54 -0
  128. tpu_inference/kernels/megablox/gmm.py +646 -0
  129. tpu_inference/kernels/mla/__init__.py +13 -0
  130. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  131. tpu_inference/kernels/mla/v1/kernel.py +20 -26
  132. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  133. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  134. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  135. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  136. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
  137. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
  138. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  139. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
  140. tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
  141. tpu_inference/layers/__init__.py +13 -0
  142. tpu_inference/layers/common/__init__.py +13 -0
  143. tpu_inference/layers/common/attention_interface.py +26 -19
  144. tpu_inference/layers/common/attention_metadata.py +14 -0
  145. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  146. tpu_inference/layers/common/quant_methods.py +15 -0
  147. tpu_inference/layers/common/quantization.py +282 -0
  148. tpu_inference/layers/common/sharding.py +22 -3
  149. tpu_inference/layers/common/utils.py +94 -0
  150. tpu_inference/layers/jax/__init__.py +13 -0
  151. tpu_inference/layers/jax/attention/__init__.py +13 -0
  152. tpu_inference/layers/jax/attention/attention.py +19 -6
  153. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
  154. tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
  155. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  156. tpu_inference/layers/jax/base.py +14 -0
  157. tpu_inference/layers/jax/constants.py +13 -0
  158. tpu_inference/layers/jax/layers.py +14 -0
  159. tpu_inference/layers/jax/misc.py +14 -0
  160. tpu_inference/layers/jax/moe/__init__.py +13 -0
  161. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  162. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  163. tpu_inference/layers/jax/moe/moe.py +43 -3
  164. tpu_inference/layers/jax/pp_utils.py +53 -0
  165. tpu_inference/layers/jax/rope.py +14 -0
  166. tpu_inference/layers/jax/rope_interface.py +14 -0
  167. tpu_inference/layers/jax/sample/__init__.py +13 -0
  168. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  169. tpu_inference/layers/jax/sample/sampling.py +15 -1
  170. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  171. tpu_inference/layers/jax/transformer_block.py +14 -0
  172. tpu_inference/layers/vllm/__init__.py +13 -0
  173. tpu_inference/layers/vllm/attention.py +4 -4
  174. tpu_inference/layers/vllm/fused_moe.py +100 -455
  175. tpu_inference/layers/vllm/linear.py +64 -0
  176. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  177. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  178. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  179. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  180. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  181. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  182. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  183. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
  184. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  188. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
  189. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  190. tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
  191. tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
  192. tpu_inference/lora/__init__.py +13 -0
  193. tpu_inference/lora/torch_lora_ops.py +8 -13
  194. tpu_inference/models/__init__.py +13 -0
  195. tpu_inference/models/common/__init__.py +13 -0
  196. tpu_inference/models/common/model_loader.py +37 -16
  197. tpu_inference/models/jax/__init__.py +13 -0
  198. tpu_inference/models/jax/deepseek_v3.py +113 -124
  199. tpu_inference/models/jax/gpt_oss.py +23 -7
  200. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  201. tpu_inference/models/jax/llama3.py +99 -36
  202. tpu_inference/models/jax/llama4.py +14 -0
  203. tpu_inference/models/jax/llama_eagle3.py +14 -0
  204. tpu_inference/models/jax/llama_guard_4.py +15 -1
  205. tpu_inference/models/jax/qwen2.py +17 -2
  206. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  207. tpu_inference/models/jax/qwen3.py +17 -2
  208. tpu_inference/models/jax/utils/__init__.py +13 -0
  209. tpu_inference/models/jax/utils/file_utils.py +14 -0
  210. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  211. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
  213. tpu_inference/models/jax/utils/weight_utils.py +32 -1
  214. tpu_inference/models/vllm/__init__.py +13 -0
  215. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
  216. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  217. tpu_inference/platforms/__init__.py +14 -0
  218. tpu_inference/platforms/tpu_platform.py +27 -29
  219. tpu_inference/runner/__init__.py +13 -0
  220. tpu_inference/runner/compilation_manager.py +69 -35
  221. tpu_inference/runner/kv_cache.py +14 -0
  222. tpu_inference/runner/kv_cache_manager.py +15 -2
  223. tpu_inference/runner/lora_utils.py +16 -1
  224. tpu_inference/runner/multimodal_manager.py +16 -2
  225. tpu_inference/runner/persistent_batch_manager.py +14 -0
  226. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  227. tpu_inference/runner/structured_decoding_manager.py +14 -0
  228. tpu_inference/runner/tpu_runner.py +30 -10
  229. tpu_inference/spec_decode/__init__.py +13 -0
  230. tpu_inference/spec_decode/jax/__init__.py +13 -0
  231. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  232. tpu_inference/tpu_info.py +14 -0
  233. tpu_inference/utils.py +31 -30
  234. tpu_inference/worker/__init__.py +13 -0
  235. tpu_inference/worker/tpu_worker.py +23 -7
  236. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
  237. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  238. tpu_inference/layers/vllm/linear_common.py +0 -208
  239. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  240. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  241. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  242. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  245. tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
  246. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  247. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  248. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,205 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import jax
16
+ import jax.numpy as jnp
17
+ from absl.testing import absltest, parameterized
18
+ from jax._src import test_util as jtu
19
+
20
+ from tpu_inference.kernels.megablox.gmm import gmm
21
+
22
+ jax.config.parse_flags_with_absl()
23
+
24
+
25
def quantize_tensor(x: jax.Array,
                    dtype: jnp.dtype,
                    axis: int = -1,
                    block_size: int = 256):
    """Symmetrically block-quantize ``x`` along ``axis``.

    ``x.shape[axis]`` is split into contiguous blocks of ``block_size``
    elements. Each block is divided by ``max(|block|) / max(dtype)``,
    clipped to the representable range of ``dtype`` and cast to ``dtype``.

    Args:
      x: Array to quantize.
      dtype: Target integer or floating-point dtype (e.g. int8, fp8, fp4).
      axis: Axis to split into blocks; negative values count from the end.
      block_size: Number of elements per block; must divide
        ``x.shape[axis]``.

    Returns:
      ``(x_q, scale)``. ``x_q`` has ``x``'s shape and ``dtype``. ``scale`` is
      float32 with ``x``'s shape except that ``axis`` has size
      ``x.shape[axis] // block_size``. All-zero blocks get a scale of 0.

    Raises:
      ValueError: If ``block_size`` does not divide ``x.shape[axis]``.
    """
    if jnp.issubdtype(dtype, jnp.integer):
        dtype_info = jnp.iinfo(dtype)
        max_val = int(dtype_info.max)
        min_val = int(dtype_info.min)
    else:
        dtype_info = jnp.finfo(dtype)
        max_val = float(dtype_info.max)
        min_val = float(dtype_info.min)

    # The shape slicing and the ``axis + 1`` reduction below only work for
    # non-negative axes. With axis=-1 the old code re-appended the whole
    # shape and reduced over axis 0, so normalize the axis first.
    if axis < 0:
        axis += x.ndim

    orig_shape = x.shape
    if orig_shape[axis] % block_size != 0:
        raise ValueError(
            f"block_size={block_size} must divide x.shape[{axis}]="
            f"{orig_shape[axis]}")
    blocked_shape = orig_shape[:axis] + (-1,
                                         block_size) + orig_shape[axis + 1:]
    x_blocked = x.reshape(blocked_shape)

    # After the reshape, the elements of each block live on axis ``axis + 1``.
    x_blocked_abs_max = jnp.max(jnp.abs(x_blocked),
                                axis=axis + 1,
                                keepdims=True)
    scale = x_blocked_abs_max / max_val
    # An all-zero block has scale 0. Divide by 1 instead so it quantizes to 0
    # rather than NaN (0 / 0). The returned scale stays 0.
    safe_scale = jnp.where(scale == 0, 1.0, scale)
    x_blocked_q = jnp.clip(x_blocked / safe_scale, min_val,
                           max_val).astype(dtype)

    x_q = x_blocked_q.reshape(orig_shape)
    scale = scale.squeeze(axis=axis + 1).astype(jnp.float32)
    return x_q, scale
52
+
53
+
54
+ def reference_gmm(
55
+ lhs: jax.Array,
56
+ rhs: jax.Array,
57
+ group_sizes: jax.Array,
58
+ rhs_scale: jax.Array | None = None,
59
+ rhs_bias: jax.Array | None = None,
60
+ group_offset: jax.Array | None = None,
61
+ ):
62
+ num_groups, out_size, in_size = rhs.shape
63
+ assert lhs.shape[1] == in_size
64
+
65
+ if group_offset is None:
66
+ group_offset = jnp.array(0, dtype=jnp.int32)
67
+ start = group_sizes[:group_offset].sum()
68
+ group_sizes = group_sizes[group_offset:]
69
+ assert len(group_sizes) == num_groups
70
+
71
+ if rhs_scale is not None:
72
+ num_blocks = rhs_scale.shape[1]
73
+ else:
74
+ num_blocks = 1
75
+ block_size = in_size // num_blocks
76
+
77
+ gmm_out = [jnp.zeros((start, out_size), lhs.dtype)]
78
+ for group in range(num_groups):
79
+ end = start + group_sizes[group]
80
+
81
+ lhs_slice = lhs[start:end]
82
+ rhs_slice = rhs[group]
83
+
84
+ out = 0
85
+ for block in range(num_blocks):
86
+ block_start = block * block_size
87
+ block_end = block_start + block_size
88
+ lhs_block = lhs_slice[:, block_start:block_end].astype(jnp.float32)
89
+ rhs_block = rhs_slice[:, block_start:block_end].astype(jnp.float32)
90
+
91
+ acc = jnp.einsum("bd,hd->bh", lhs_block, rhs_block)
92
+ if rhs_scale is not None:
93
+ acc *= rhs_scale[group][block]
94
+ out += acc
95
+ if rhs_bias is not None:
96
+ out = out + rhs_bias[group]
97
+
98
+ gmm_out.append(out.astype(lhs.dtype))
99
+ start = end
100
+
101
+ return jnp.concat(gmm_out, axis=0)
102
+
103
+
104
@jtu.with_config(jax_numpy_dtype_promotion="standard")
class GmmTest(jtu.JaxTestCase):
    """Compares the megablox ``gmm`` kernel with ``reference_gmm``."""

    @staticmethod
    def _make_inputs(batch_size, in_size, out_size, num_groups, has_bias):
        """Return ``(lhs, rhs, rhs_bias, group_sizes)`` from independent keys.

        Each tensor gets its own subkey. Drawing every tensor from one shared
        key can make them correlated (overlapping random streams) instead of
        independent samples.
        """
        k_lhs, k_rhs, k_bias, k_sizes = jax.random.split(jax.random.key(0), 4)

        lhs = jax.random.normal(k_lhs, (batch_size, in_size),
                                dtype=jnp.bfloat16)
        rhs = jax.random.normal(k_rhs, (num_groups, out_size, in_size),
                                dtype=jnp.bfloat16)
        rhs_bias = None
        if has_bias:
            rhs_bias = jax.random.normal(k_bias, (num_groups, 1, out_size),
                                         dtype=jnp.bfloat16)

        # NOTE(review): each size is drawn independently from
        # [0, batch_size), so the total may exceed batch_size. Confirm this
        # is the intended input domain for gmm.
        group_sizes = jax.random.randint(k_sizes, (num_groups, ),
                                         0,
                                         batch_size,
                                         dtype=jnp.int32)
        return lhs, rhs, rhs_bias, group_sizes

    @parameterized.product(
        batch_size=[128],
        in_size=[1024],
        out_size=[1024],
        num_groups=[16, 32],
        has_bias=[True, False],
    )
    def test_gmm(self, batch_size, in_size, out_size, num_groups, has_bias):
        """Unquantized bf16 ``gmm`` with ``transpose_rhs=True`` matches the reference."""
        lhs, rhs, rhs_bias, group_sizes = self._make_inputs(
            batch_size, in_size, out_size, num_groups, has_bias)

        expected = reference_gmm(lhs, rhs, group_sizes, rhs_bias=rhs_bias)
        actual = gmm(
            lhs,
            rhs,
            group_sizes,
            rhs_bias=rhs_bias,
            transpose_rhs=True,
            preferred_element_type=jnp.bfloat16,
        )

        self.assertArraysAllClose(actual, expected)

    @parameterized.product(
        batch_size=[128],
        in_size=[1024],
        out_size=[1024],
        num_groups=[16, 32],
        has_bias=[True, False],
        weight_dtype=[jnp.int8, jnp.float8_e5m2, jnp.float4_e2m1fn],
        block_size=[256, 512],
    )
    def test_gmm_weight_quantized(
        self,
        batch_size,
        in_size,
        out_size,
        num_groups,
        has_bias,
        weight_dtype,
        block_size,
    ):
        """Block-quantized weights with per-block scales match the reference."""
        if weight_dtype == jnp.float4_e2m1fn and not jtu.is_device_tpu_at_least(
                version=7):
            self.skipTest("Expect TPUv7+")

        lhs, rhs, rhs_bias, group_sizes = self._make_inputs(
            batch_size, in_size, out_size, num_groups, has_bias)

        rhs_q, rhs_scale = quantize_tensor(rhs,
                                           weight_dtype,
                                           axis=2,
                                           block_size=block_size)
        # quantize_tensor(axis=2) yields scales of shape
        # (num_groups, out_size, num_blocks). Rearrange them to
        # (num_groups, num_blocks, 1, out_size), the layout passed to both
        # reference_gmm and gmm.
        rhs_scale = jnp.swapaxes(rhs_scale, 1, 2)
        rhs_scale = jnp.expand_dims(rhs_scale, axis=2)

        expected = reference_gmm(lhs,
                                 rhs_q,
                                 group_sizes,
                                 rhs_scale=rhs_scale,
                                 rhs_bias=rhs_bias)
        actual = gmm(
            lhs,
            rhs_q,
            group_sizes,
            rhs_scale=rhs_scale,
            rhs_bias=rhs_bias,
            transpose_rhs=True,
            preferred_element_type=jnp.bfloat16,
        )

        # Loose tolerance: quantized weights only approximate the bf16 ones.
        self.assertArraysAllClose(actual, expected, atol=3e-1, rtol=3e-1)
202
+
203
+
204
if __name__ == "__main__":
    # Run the test cases through absltest with JAX's test loader.
    test_loader = jtu.JaxTestLoader()
    absltest.main(testLoader=test_loader)
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  import jax
2
16
  import jax.numpy as jnp
3
17
  import numpy as np
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  import jax
2
16
  import jax.numpy as jnp
3
17
  import numpy as np
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  import random
2
16
 
3
17
  import jax
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  import jax
2
16
  import jax.numpy as jnp
3
17
  import numpy as np
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  import jax
2
16
  import jax.numpy as jnp
3
17
  import numpy as np
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,156 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from unittest.mock import MagicMock
16
+
17
+ import jax
18
+ import jax.numpy as jnp
19
+ import numpy as np
20
+ import pytest
21
+ from jax.sharding import Mesh
22
+
23
+ from tpu_inference.layers.common.attention_interface import attention
24
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
25
+ from tpu_inference.runner.kv_cache import get_kv_cache_shape_with_mesh
26
+
27
# ---- Test Configuration & Constants ----

# Batch layout.
TOTAL_TOKENS = 10  # tokens summed over every sequence in the batch
NUM_SEQS = 2  # sequences actually present in the batch
MAX_NUM_SEQS = 4  # padded upper bound on the number of sequences

# Attention heads (fewer KV heads than query heads => grouped-query attention).
NUM_HEADS = 8  # query heads
NUM_KV_HEADS = 4  # key/value heads

# Paged KV cache geometry.
NUM_BLOCKS = 32  # blocks in the whole KV cache
BLOCK_SIZE = 16  # tokens stored per block
MAX_BLOCKS_PER_SEQ = 8  # most blocks one sequence may occupy
45
+
46
+
47
@pytest.fixture
def mesh():
    """Provide a single-device JAX mesh with axes ("data", "attn_dp", "model").

    At most one local device is used and placed on the "data" axis; the
    "attn_dp" and "model" axes have size 1.
    """
    # Take at most one local device (on a CPU-only host, the CPU device).
    devices = np.array(jax.local_devices()[:1])
    # Check emptiness explicitly. `devices.any()` on an object array of
    # Device objects depends on their truthiness rather than on the count.
    if devices.size == 0:
        # Fall back to an explicit CPU device if no local devices are listed.
        devices = np.array([jax.devices("cpu")[0]])
    return Mesh(devices.reshape((-1, 1, 1)), ("data", "attn_dp", "model"))
57
+
58
+
59
+ # ---- Test for `attention` ----
60
+
61
+
62
def _test_attention(monkeypatch, mesh, head_dim, use_sinks=False):
    """Run `attention` once against a mocked paged-attention kernel.

    The kernel is replaced through ``monkeypatch`` with a ``MagicMock`` that
    returns ``(ones((TOTAL_TOKENS, NUM_HEADS, head_dim)), kv_cache)``. The
    patched kernel is ``ragged_paged_attention_hd64`` when ``head_dim == 64``
    and ``ragged_paged_attention`` otherwise.

    Verifies that:
      1. The patched kernel is called exactly once.
      2. The returned kv_cache has the input cache's shape, and the attention
         output has ``q``'s shape.
      3. The attention output is the mock's all-ones tensor.

    Args:
      monkeypatch: pytest ``monkeypatch`` fixture used to patch the kernel.
      mesh: Mesh passed to ``get_kv_cache_shape_with_mesh`` and ``attention``.
      head_dim: Head dimension of q/k/v; 64 selects the hd64 kernel.
      use_sinks: If True, pass an all-ones float32 ``sinks`` array of shape
        ``(NUM_HEADS,)`` to ``attention``.
    """
    # 1. Arrange

    # All-ones float32 input tensors.
    q_dtype = jnp.float32
    kv_dtype = jnp.float32
    q = jnp.ones((TOTAL_TOKENS, NUM_HEADS, head_dim), dtype=q_dtype)
    k = jnp.ones((TOTAL_TOKENS, NUM_KV_HEADS, head_dim), dtype=kv_dtype)
    v = jnp.ones((TOTAL_TOKENS, NUM_KV_HEADS, head_dim), dtype=kv_dtype)
    sinks = jnp.ones((NUM_HEADS, ), dtype=jnp.float32) if use_sinks else None

    kv_cache_shape = get_kv_cache_shape_with_mesh(
        mesh,
        NUM_BLOCKS,
        BLOCK_SIZE,
        NUM_KV_HEADS,
        head_dim,
        kv_dtype,
    )
    kv_cache = jnp.zeros(kv_cache_shape, dtype=kv_dtype)

    # Stub out the kernel so only `attention`'s plumbing is exercised. The
    # mock returns an all-ones output and the input cache unchanged.
    mock_paged_attn_kernel = MagicMock(return_value=(jnp.ones(
        (TOTAL_TOKENS, NUM_HEADS, head_dim)), kv_cache))

    # The test expects `attention` to use the hd64 kernel when head_dim == 64.
    kernel_target = (
        "tpu_inference.layers.common.attention_interface.ragged_paged_attention_hd64"
        if head_dim == 64 else
        "tpu_inference.layers.common.attention_interface.ragged_paged_attention"
    )
    monkeypatch.setattr(kernel_target, mock_paged_attn_kernel)

    # Metadata for NUM_SEQS=2 sequences of 5 tokens each, padded to
    # MAX_NUM_SEQS entries.
    attention_metadata = AttentionMetadata(
        input_positions=jnp.arange(TOTAL_TOKENS, dtype=jnp.int32),
        block_tables=jnp.zeros((MAX_NUM_SEQS * MAX_BLOCKS_PER_SEQ, ),
                               dtype=jnp.int32),
        seq_lens=jnp.array([5, 5, 0, 0], dtype=jnp.int32),
        query_start_loc=jnp.array([0, 5, 10, 10, 10], dtype=jnp.int32),
        # NOTE(review): AttentionMetadata defines what the three counters
        # mean; here all NUM_SEQS requests are counted in the last slot.
        request_distribution=jnp.array([0, 0, NUM_SEQS], dtype=jnp.int32),
    )

    # 2. Act
    final_kv_cache, output = attention(
        kv_cache=kv_cache,
        q=q,
        k=k,
        v=v,
        attention_metadata=attention_metadata,
        mesh=mesh,
        head_dim_original=head_dim,
        sinks=sinks,
    )

    # 3. Assert
    # The single patched kernel must have been called exactly once.
    mock_paged_attn_kernel.assert_called_once()

    # Output shapes match the inputs.
    assert final_kv_cache.shape == kv_cache.shape
    assert output.shape == q.shape

    # The output is the mock's all-ones tensor.
    assert jnp.all(output == 1.0)
137
+
138
+
139
def test_attention(monkeypatch, mesh):
    """head_dim=128 path without attention sinks."""
    _test_attention(monkeypatch, mesh, head_dim=128)
141
+
142
+
143
def test_attention_hd64(monkeypatch, mesh):
    """head_dim=64 path (hd64 kernel) without attention sinks."""
    _test_attention(monkeypatch, mesh, head_dim=64)
145
+
146
+
147
def test_attention_sink(monkeypatch, mesh):
    """head_dim=64 path with attention sinks enabled."""
    _test_attention(monkeypatch, mesh, head_dim=64, use_sinks=True)
149
+
150
+
151
def test_attention_sink_no_64_raises_error(monkeypatch, mesh):
    """Attention sinks with head_dim != 64 must raise NotImplementedError."""
    expected_message = "Attention sink support is only available when head_dim==64"
    with pytest.raises(NotImplementedError, match=expected_message):
        _test_attention(monkeypatch, mesh, head_dim=128, use_sinks=True)
@@ -0,0 +1,149 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import functools
16
+
17
+ import jax
18
+ import jax.numpy as jnp
19
+ from absl.testing import absltest, parameterized
20
+ from jax._src import test_util as jtu
21
+
22
+ from tpu_inference.layers.common.quantization import (
23
+ dequantize_tensor, dequantize_tensor_from_mxfp4_packed, quantize_kv,
24
+ quantize_tensor, quantize_tensor_to_mxfp4_packed)
25
+
26
+
27
@jtu.with_config(jax_numpy_dtype_promotion="standard")
class QuantizationTest(jtu.JaxTestCase):
    """Round-trip tests for the tpu_inference quantization helpers.

    Most tests quantize a random bfloat16 tensor, dequantize it again and
    check that the result stays within a tolerance of the original.
    """

    @staticmethod
    def _random_bf16(shape, key=None):
        """Return a standard-normal bfloat16 array of the given ``shape``.

        Falls back to ``jax.random.key(0)`` when ``key`` is omitted, so test
        inputs are deterministic across runs.
        """
        if key is None:
            key = jax.random.key(0)
        return jax.random.normal(key, shape, jnp.bfloat16)

    @parameterized.product(axis=[-1, 0, (0, 1)])
    def test_mxfp4_quantization(self, axis):
        """mxfp4 packed quantization round-trips along ``axis``."""
        if not jtu.is_device_tpu_at_least(version=7):
            self.skipTest("mxfp4 is only supported in TPUv7+")

        original = self._random_bf16((128, 128, 128))

        tensor_q, scale = quantize_tensor_to_mxfp4_packed(original, axis)
        dequantized = dequantize_tensor_from_mxfp4_packed(
            tensor_q, scale, axis)

        # mxfp4 is a 4-bit format, hence the much looser tolerance than the
        # 8-bit round-trip tests below.
        self.assertAllClose(dequantized, original, rtol=0.5, atol=0.5)

    @parameterized.product(dtype=[jnp.float8_e4m3fn, jnp.int8],
                           axis=[None, -1, 1, (0, 1)])
    def test_quantization(self, dtype, axis):
        """Non-blocked quantization round-trips for each dtype/axis spec."""
        original = self._random_bf16((128, 128, 128))

        tensor_q, scale = quantize_tensor(dtype, original, axis)
        dequantized = dequantize_tensor(tensor_q, scale, axis)

        self.assertAllClose(dequantized, original, rtol=0.1, atol=0.1)

    @parameterized.product(dtype=[jnp.float8_e4m3fn, jnp.int8],
                           axis=[-1, 1],
                           block_size=[32, 64])
    def test_block_quantization(self, dtype, axis, block_size):
        """Block-wise quantization along a single axis round-trips."""
        original = self._random_bf16((128, 128, 128))

        tensor_q, scale = quantize_tensor(dtype, original, axis, block_size)
        dequantized = dequantize_tensor(tensor_q, scale, axis)

        self.assertAllClose(dequantized, original, rtol=0.1, atol=0.1)

    @parameterized.product(dtype=[jnp.float8_e4m3fn, jnp.int8],
                           axis=[(0, 1), (-1, 0)],
                           block_size=[32, (64, 32)])
    def test_multi_block_quantization(self, dtype, axis, block_size):
        """Block-wise quantization over several axes round-trips.

        ``block_size`` is either a single int or a tuple (presumably one
        size per entry of ``axis``, matched positionally -- confirm).
        """
        original = self._random_bf16((128, 128, 128))

        tensor_q, scale = quantize_tensor(dtype, original, axis, block_size)
        dequantized = dequantize_tensor(tensor_q, scale, axis)

        self.assertAllClose(dequantized, original, rtol=0.1, atol=0.1)

    def test_unaligned_block_quantization_raises_error(self):
        """A block size that does not divide the axis raises ValueError."""
        tensor = self._random_bf16((128, 128))
        block_size = 100  # 128 is not a multiple of 100.
        axis = 0

        self.assertRaises(
            ValueError,
            functools.partial(quantize_tensor, jnp.int8, tensor, axis,
                              block_size))

    def test_block_quantization_padding(self):
        """An unaligned axis is zero-padded to a multiple of block_size.

        Also checks that the unpadded region still round-trips.
        """
        shape = (128, 128)
        original = self._random_bf16(shape)
        block_size = 100
        axis = 0

        # The trailing positional ``True`` presumably enables padding;
        # confirm the parameter name in quantize_tensor.
        tensor_q, scale = quantize_tensor(jnp.int8, original, axis, block_size,
                                          True)
        dequantized = dequantize_tensor(tensor_q, scale, axis)

        orig_len = shape[axis]
        # Round up to the next multiple of block_size (ceil division).
        # `(n + block_size) // block_size` would over-pad by a whole block
        # whenever n is already a multiple of block_size.
        padded_size = -(-orig_len // block_size) * block_size
        self.assertEqual(tensor_q.shape[axis], padded_size)
        # axis == 0, so the padding occupies the trailing rows.
        self.assertTrue((tensor_q[orig_len:] == 0).all())
        self.assertAllClose(dequantized[:orig_len],
                            original,
                            rtol=0.1,
                            atol=0.1)

    @parameterized.product(kv_quant_dtype=[jnp.float8_e4m3fn, jnp.int8])
    def test_quantize_kv(self, kv_quant_dtype):
        """Tests quantize_kv for every parameterized KV quantization dtype.

        The check multiplies the quantized values back by the per-tensor
        scale, so it assumes quantize_kv divides K and V by their respective
        scales before casting to ``kv_quant_dtype``.
        """
        shape = (128, 128)
        # Use independent K and V so the test fails if quantize_kv mixes up
        # the two tensors; identical inputs would hide that.
        k_key, v_key = jax.random.split(jax.random.key(0))
        k_original = self._random_bf16(shape, k_key)
        v_original = self._random_bf16(shape, v_key)
        k_scale = 0.1
        v_scale = 0.2

        k_quantized, v_quantized = quantize_kv(
            kv_quant_dtype,
            k_original,
            v_original,
            k_scale,
            v_scale,
        )

        k_dequantized = k_quantized.astype(jnp.bfloat16) * k_scale
        v_dequantized = v_quantized.astype(jnp.bfloat16) * v_scale

        self.assertAllClose(k_dequantized, k_original, rtol=0.2, atol=0.2)
        self.assertAllClose(v_dequantized, v_original, rtol=0.2, atol=0.2)
146
+
147
+
148
if __name__ == "__main__":
    # Use JAX's test loader, matching the jtu.JaxTestCase base class of the
    # tests in this module.
    loader = jtu.JaxTestLoader()
    absltest.main(testLoader=loader)
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.