tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.

This version of tpu-inference might be problematic.
Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tests/layers/common/test_quantization.py
@@ -0,0 +1,149 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import functools
+
+ import jax
+ import jax.numpy as jnp
+ from absl.testing import absltest, parameterized
+ from jax._src import test_util as jtu
+
+ from tpu_inference.layers.common.quantization import (
+     dequantize_tensor, dequantize_tensor_from_mxfp4_packed, quantize_kv,
+     quantize_tensor, quantize_tensor_to_mxfp4_packed)
+
+
+ @jtu.with_config(jax_numpy_dtype_promotion="standard")
+ class QuantizationTest(jtu.JaxTestCase):
+
+     @parameterized.product(axis=[-1, 0, (0, 1)])
+     def test_mxfp4_quantization(self, axis):
+         if not jtu.is_device_tpu_at_least(version=7):
+             self.skipTest("mxfp4 is only supported in TPUv7+")
+
+         key = jax.random.key(0)
+
+         shape = (128, 128, 128)
+         original = jax.random.normal(key, shape, jnp.bfloat16)
+
+         tensor_q, scale = quantize_tensor_to_mxfp4_packed(original, axis)
+         dequantized = dequantize_tensor_from_mxfp4_packed(
+             tensor_q, scale, axis)
+
+         self.assertAllClose(dequantized, original, rtol=0.5, atol=0.5)
+
+     @parameterized.product(dtype=[jnp.float8_e4m3fn, jnp.int8],
+                            axis=[None, -1, 1, (0, 1)])
+     def test_quantization(self, dtype, axis):
+         key = jax.random.key(0)
+
+         shape = (128, 128, 128)
+         original = jax.random.normal(key, shape, jnp.bfloat16)
+
+         tensor_q, scale = quantize_tensor(dtype, original, axis)
+         dequantized = dequantize_tensor(tensor_q, scale, axis)
+
+         self.assertAllClose(dequantized, original, rtol=0.1, atol=0.1)
+
+     @parameterized.product(dtype=[jnp.float8_e4m3fn, jnp.int8],
+                            axis=[-1, 1],
+                            block_size=[32, 64])
+     def test_block_quantization(self, dtype, axis, block_size):
+         key = jax.random.key(0)
+
+         shape = (128, 128, 128)
+         original = jax.random.normal(key, shape, jnp.bfloat16)
+
+         tensor_q, scale = quantize_tensor(dtype, original, axis, block_size)
+         dequantized = dequantize_tensor(tensor_q, scale, axis)
+
+         self.assertAllClose(dequantized, original, rtol=0.1, atol=0.1)
+
+     @parameterized.product(dtype=[jnp.float8_e4m3fn, jnp.int8],
+                            axis=[(0, 1), (-1, 0)],
+                            block_size=[32, (64, 32)])
+     def test_multi_block_quantization(self, dtype, axis, block_size):
+         key = jax.random.key(0)
+
+         shape = (128, 128, 128)
+         original = jax.random.normal(key, shape, jnp.bfloat16)
+
+         tensor_q, scale = quantize_tensor(dtype, original, axis, block_size)
+         dequantized = dequantize_tensor(tensor_q, scale, axis)
+
+         self.assertAllClose(dequantized, original, rtol=0.1, atol=0.1)
+
+     def test_unaligned_block_quantization_raises_error(self):
+         key = jax.random.key(0)
+
+         shape = (128, 128)
+         tensor = jax.random.normal(key, shape, jnp.bfloat16)
+         block_size = 100
+         axis = 0
+
+         self.assertRaises(
+             ValueError,
+             functools.partial(quantize_tensor, jnp.int8, tensor, axis,
+                               block_size))
+
+     def test_block_quantization_padding(self):
+         key = jax.random.key(0)
+
+         shape = (128, 128)
+
+         original = jax.random.normal(key, shape, jnp.bfloat16)
+         block_size = 100
+         axis = 0
+
+         tensor_q, scale = quantize_tensor(jnp.int8, original, axis, block_size,
+                                           True)
+
+         dequantized = dequantize_tensor(tensor_q, scale, axis)
+
+         padded_size = ((shape[axis] + block_size) // block_size) * block_size
+         self.assertEqual(tensor_q.shape[axis], padded_size)
+         self.assertTrue((tensor_q[shape[0]:] == 0).all())
+         self.assertAllClose(dequantized[:shape[0]],
+                             original,
+                             rtol=0.1,
+                             atol=0.1)
+
+     @parameterized.product(kv_quant_dtype=[jnp.float8_e4m3fn, jnp.int8])
+     def test_quantize_kv(self, kv_quant_dtype):
+         """Tests the quantize_kv function with float8_e4m3fn and int8 dtypes."""
+         key = jax.random.key(0)
+
+         shape = (128, 128)
+         k_original = jax.random.normal(key, shape, jnp.bfloat16)
+         v_original = jax.random.normal(key, shape, jnp.bfloat16)
+         k_scale = 0.1
+         v_scale = 0.2
+
+         k_quantized, v_quantized = quantize_kv(
+             kv_quant_dtype,
+             k_original,
+             v_original,
+             k_scale,
+             v_scale,
+         )
+
+         k_dequantized = k_quantized.astype(jnp.bfloat16) * k_scale
+         v_dequantized = v_quantized.astype(jnp.bfloat16) * v_scale
+
+         self.assertAllClose(k_dequantized, k_original, rtol=0.2, atol=0.2)
+         self.assertAllClose(v_dequantized, v_original, rtol=0.2, atol=0.2)
+
+
+ if __name__ == "__main__":
+     absltest.main(testLoader=jtu.JaxTestLoader())
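
For orientation, the round trip these tests exercise follows the usual absmax pattern. Below is a minimal pure-JAX sketch of a symmetric per-axis int8 round trip; it is an illustration only, and the real quantize_tensor/dequantize_tensor in tpu_inference.layers.common.quantization (block sizes, padding, fp8 handling) may differ in detail.

import jax
import jax.numpy as jnp

def quantize_int8(x, axis):
    # One absmax scale per slice along `axis`.
    amax = jnp.max(jnp.abs(x).astype(jnp.float32), axis=axis, keepdims=True)
    scale = jnp.maximum(amax / 127.0, 1e-8)  # avoid div-by-zero on all-zero slices
    q = jnp.clip(jnp.round(x.astype(jnp.float32) / scale), -127, 127).astype(jnp.int8)
    return q, scale

def dequantize_int8(q, scale):
    return (q.astype(jnp.float32) * scale).astype(jnp.bfloat16)

x = jax.random.normal(jax.random.key(0), (128, 128), jnp.bfloat16)
q, s = quantize_int8(x, axis=-1)
assert jnp.allclose(dequantize_int8(q, s), x, rtol=0.1, atol=0.1)  # same tolerance as the tests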
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
tests/layers/jax/attention/test_common_attention.py
@@ -0,0 +1,103 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import unittest
+ from typing import Tuple
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ from flax import nnx
+ from jax.sharding import Mesh
+ from parameterized import parameterized
+
+ from tpu_inference.layers.common.attention_interface import get_kv_cache_shape
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.jax.attention.attention import Attention
+
+ KVCache = Tuple[jax.Array, jax.Array]
+
+
+ class TestAttention(unittest.TestCase):
+     """Unit test suite for the JAX Attention module."""
+
+     def setUp(self):
+         """Sets up the testing environment before each test."""
+         self.mesh = Mesh(
+             np.array(jax.devices()[:1]).reshape(1, 1, 1, -1),
+             axis_names=(
+                 "data",
+                 "attn_dp",
+                 "expert",
+                 "model",
+             ),
+         )
+
+     @parameterized.expand([["auto"], ["fp8"]])
+     def test_attention_forward_pass(self, kv_cache_str):
+         """Tests the forward pass of the Attention module in prefill mode."""
+         hidden_size = 1024
+         num_attention_heads = 8
+         head_dim = hidden_size // num_attention_heads
+
+         with jax.set_mesh(self.mesh):
+             attention = Attention(hidden_size=hidden_size,
+                                   num_attention_heads=num_attention_heads,
+                                   num_key_value_heads=num_attention_heads,
+                                   head_dim=head_dim,
+                                   rope_theta=10000.0,
+                                   rope_scaling={},
+                                   dtype=jnp.bfloat16,
+                                   mesh=self.mesh,
+                                   random_init=True,
+                                   rngs=nnx.Rngs(42),
+                                   kv_cache_dtype=kv_cache_str)
+
+             seq_len = 64
+             x = jnp.ones((seq_len, hidden_size), dtype=jnp.bfloat16)
+
+             block_size = 16
+             num_blocks = 8
+             kv_dtype = jnp.float8_e4m3fn if kv_cache_str == "fp8" else jnp.bfloat16
+             cache_shape = get_kv_cache_shape(num_blocks, block_size,
+                                              num_attention_heads, head_dim,
+                                              kv_dtype)
+
+             kv_cache = jnp.zeros(cache_shape, dtype=kv_dtype)
+
+             num_required_blocks = seq_len // block_size
+
+             attention_metadata = AttentionMetadata(
+                 input_positions=jnp.arange(seq_len, dtype=jnp.int32),
+                 block_tables=jnp.array(list(range(num_required_blocks)),
+                                        dtype=jnp.int32),
+                 seq_lens=jnp.array([seq_len], dtype=jnp.int32),
+                 query_start_loc=jnp.array([0, seq_len], dtype=jnp.int32),
+                 request_distribution=jnp.array([0, 0, 1], dtype=jnp.int32),
+             )
+
+             new_kv_cache, output = attention(
+                 x,
+                 is_prefill=True,
+                 kv_cache=kv_cache,
+                 attention_metadata=attention_metadata,
+             )
+
+             self.assertEqual(output.shape, (seq_len, hidden_size))
+
+             self.assertEqual(new_kv_cache.shape, kv_cache.shape)
+
+
+ if __name__ == "__main__":
+     unittest.main()
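
The AttentionMetadata built above describes a ragged batch in flattened-token form: query_start_loc holds cumulative token offsets per request, and block_tables maps requests to KV-cache blocks. A hedged sketch of the same fields for a hypothetical two-request batch (semantics inferred from this test, not an official API reference):

import jax.numpy as jnp

block_size = 16
seq_lens = jnp.array([48, 16], dtype=jnp.int32)  # two requests
# Request i owns flattened tokens [query_start_loc[i], query_start_loc[i + 1]).
query_start_loc = jnp.concatenate(
    [jnp.zeros((1,), jnp.int32), jnp.cumsum(seq_lens)])  # [0, 48, 64]
# Each request needs ceil(seq_len / block_size) KV-cache blocks.
blocks_per_req = (seq_lens + block_size - 1) // block_size  # [3, 1]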
tests/layers/jax/attention/test_deepseek_v3_attention.py
@@ -0,0 +1,233 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import unittest
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ from flax import nnx
+ from jax.sharding import Mesh, PartitionSpec
+ from parameterized import parameterized
+
+ import tpu_inference.kernels.mla.v1.kernel as mla
+ from tpu_inference.layers.common.attention_interface import get_kv_cache_shape
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.common.sharding import ShardingAxisName
+ from tpu_inference.layers.jax.attention.deepseek_v3_attention import MLA
+
+
+ class TestMLA(unittest.TestCase):
+
+     def setUp(self):
+         os.environ["NEW_MODEL_DESIGN"] = "1"
+         self.mesh = Mesh(
+             np.array(jax.devices("tpu")[:1]).reshape(1, 1, 1, 1),
+             axis_names=("data", "attn_dp", "expert", "model"),
+         )
+
+     @parameterized.expand([["auto"], ["fp8"]])
+     def test_mla_forward_pass(self, kv_cache_str):
+         hidden_size = 256
+
+         num_key_value_heads = 32
+         qk_nope_head_dim = 64
+         qk_rope_head_dim = 32
+
+         with jax.set_mesh(self.mesh):
+             query_tnh_spec = PartitionSpec(None, ShardingAxisName.MLP_TENSOR,
+                                            None)
+             keyvalue_skh_spec = PartitionSpec(None,
+                                               ShardingAxisName.MLP_TENSOR,
+                                               None)
+             attn_o_tnh_spec = PartitionSpec(None, ShardingAxisName.MLP_TENSOR,
+                                             None)
+
+             mla_layer = MLA(
+                 hidden_size=hidden_size,
+                 num_attention_heads=32,
+                 num_key_value_heads=num_key_value_heads,
+                 head_dim=64,  # MLA uses v_head_dim as head_dim
+                 rope_theta=10000,
+                 dtype=jnp.bfloat16,
+                 q_lora_rank=512,
+                 kv_lora_rank=512,
+                 qk_nope_head_dim=
+                 qk_nope_head_dim,  # Half of DeepSeek v3's real values
+                 qk_rope_head_dim=
+                 qk_rope_head_dim,  # Half of DeepSeek v3's real values
+                 v_head_dim=64,  # Half of DeepSeek v3's real values
+                 rms_norm_eps=1e-5,
+                 rngs=nnx.Rngs(42),
+                 rope_scaling={
+                     "beta_fast": 32,
+                     "beta_slow": 1,
+                     "factor": 40,
+                     "mscale": 1.0,
+                     "mscale_all_dim": 1.0,
+                     "original_max_position_embeddings": 4096,
+                     "type": "yarn",
+                 },
+                 mesh=self.mesh,
+                 random_init=True,
+                 kv_cache_dtype=kv_cache_str,
+                 query_tnh=query_tnh_spec,
+                 keyvalue_skh=keyvalue_skh_spec,
+                 attn_o_tnh=attn_o_tnh_spec,
+                 q_da_sharding=(None, ShardingAxisName.VOCAB),
+                 anh_sharding=(None, ShardingAxisName.MLP_TENSOR, None),
+                 ap_sharding=(None, ShardingAxisName.MLP_TENSOR),
+                 kv_da_sharding=(None, ShardingAxisName.VOCAB),
+                 rd_sharding=(ShardingAxisName.MLP_TENSOR, None),
+             )
+
+             # Create input tensor
+             seq_len = 32
+             x = jnp.ones((seq_len, hidden_size), dtype=jnp.bfloat16)
+
+             # Create KV cache
+             qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+             block_size = 16
+             num_blocks = 8
+             kv_dtype = jnp.float8_e4m3fn if kv_cache_str == "fp8" else jnp.bfloat16
+             cache_shape = get_kv_cache_shape(num_blocks, block_size,
+                                              num_key_value_heads, qk_head_dim,
+                                              kv_dtype)
+             kv_cache = jnp.zeros(cache_shape, dtype=kv_dtype)
+
+             # Create attention metadata
+             attention_metadata = AttentionMetadata(
+                 input_positions=jnp.arange(seq_len, dtype=jnp.int32),
+                 block_tables=jnp.zeros((8, ), dtype=jnp.int32),
+                 seq_lens=jnp.ones((1, ), dtype=jnp.int32) * seq_len,
+                 query_start_loc=jnp.array(
+                     [0, seq_len], dtype=jnp.int32),  # This is cu_q_lens
+                 request_distribution=jnp.array([0, 0, 1], dtype=jnp.int32),
+             )
+
+             mla_layer.rope.initialize_cache(self.mesh)
+
+             # Run forward pass
+             new_kv_cache, output = mla_layer(
+                 x,
+                 is_prefill=True,
+                 kv_cache=kv_cache,
+                 attention_metadata=attention_metadata)
+
+             # Verify output shapes
+             self.assertEqual(output.shape, (seq_len, hidden_size))
+             self.assertEqual(new_kv_cache.shape, kv_cache.shape)
+
+     @parameterized.expand([["auto"]])  # MLA kernel does not support fp8 yet
+     def test_mla_kernel_forward_pass(self, kv_cache_str):
+         hidden_size = 256
+
+         num_key_value_heads = 1
+         qk_nope_head_dim = 64
+         qk_rope_head_dim = 32
+         v_head_dim = 64
+         kv_lora_rank = 512
+
+         with jax.set_mesh(self.mesh):
+             query_tnh_spec = PartitionSpec(ShardingAxisName.MLP_TENSOR, None,
+                                            None)
+             keyvalue_skh_spec = PartitionSpec(ShardingAxisName.MLP_TENSOR,
+                                               None)
+             attn_o_tnh_spec = PartitionSpec(ShardingAxisName.MLP_TENSOR, None,
+                                             None)
+
+             mla_layer = MLA(
+                 hidden_size=hidden_size,
+                 num_attention_heads=32,
+                 num_key_value_heads=num_key_value_heads,
+                 head_dim=v_head_dim,  # MLA uses v_head_dim as head_dim
+                 rope_theta=10000,
+                 dtype=jnp.bfloat16,
+                 q_lora_rank=512,
+                 kv_lora_rank=kv_lora_rank,
+                 qk_nope_head_dim=qk_nope_head_dim,
+                 qk_rope_head_dim=qk_rope_head_dim,
+                 v_head_dim=v_head_dim,
+                 rms_norm_eps=1e-5,
+                 rngs=nnx.Rngs(42),
+                 rope_scaling={
+                     "beta_fast": 32,
+                     "beta_slow": 1,
+                     "factor": 40,
+                     "mscale": 1.0,
+                     "mscale_all_dim": 1.0,
+                     "original_max_position_embeddings": 4096,
+                     "type": "yarn",
+                 },
+                 mesh=self.mesh,
+                 random_init=True,
+                 kv_cache_dtype=kv_cache_str,
+                 use_mla_kernel=
+                 True,  # Set to True in order to trigger the MLA kernel.
+                 query_tnh=query_tnh_spec,
+                 keyvalue_skh=keyvalue_skh_spec,
+                 attn_o_tnh=attn_o_tnh_spec,
+                 q_da_sharding=(None, ShardingAxisName.VOCAB),
+                 anh_sharding=(None, ShardingAxisName.MLP_TENSOR, None),
+                 ap_sharding=(None, ShardingAxisName.MLP_TENSOR),
+                 kv_da_sharding=(None, ShardingAxisName.VOCAB),
+                 rd_sharding=(ShardingAxisName.MLP_TENSOR, None),
+             )
+
+             # Create input tensor
+             seq_len = 32
+             x = jnp.ones((seq_len, hidden_size), dtype=jnp.bfloat16)
+
+             # Create KV cache for the MLA kernel
+             block_size = 16
+             num_blocks = 8
+             kv_dtype = jnp.float8_e4m3fn if kv_cache_str == "fp8" else jnp.bfloat16
+
+             # For the MLA kernel, the per-token cache width is the sum of
+             # kv_lora_rank and qk_rope_head_dim.
+             cache_shape = mla.get_kv_cache_shape(
+                 num_blocks, block_size, kv_lora_rank + qk_rope_head_dim,
+                 kv_dtype)
+             kv_cache = jnp.zeros(cache_shape, dtype=kv_dtype)
+
+             # Create attention metadata
+             attention_metadata = AttentionMetadata(
+                 input_positions=jnp.arange(seq_len, dtype=jnp.int32),
+                 block_tables=jnp.zeros((8, ), dtype=jnp.int32),
+                 seq_lens=jnp.ones((1, ), dtype=jnp.int32) * seq_len,
+                 query_start_loc=jnp.array([0, seq_len], dtype=jnp.int32),
+                 request_distribution=jnp.array([0, 0, 1], dtype=jnp.int32),
+             )
+
+             mla_layer.rope.initialize_cache(self.mesh)
+
+             # Run forward pass
+             new_kv_cache, output = mla_layer(
+                 x,
+                 is_prefill=True,
+                 kv_cache=kv_cache,
+                 attention_metadata=attention_metadata)
+
+             # Verify output shapes
+             self.assertEqual(output.shape, (seq_len, hidden_size))
+             self.assertEqual(new_kv_cache.shape, kv_cache.shape)
+
+
+ def tearDownModule():
+     del os.environ["NEW_MODEL_DESIGN"]
+
+
+ if __name__ == "__main__":
+     unittest.main()
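
A note on the cache shape in test_mla_kernel_forward_pass: the MLA kernel caches one compressed latent of width kv_lora_rank + qk_rope_head_dim per token instead of full per-head K and V. With this test's (halved) DeepSeek dimensions, the saving works out as follows; the uncompressed baseline here is an assumption for illustration, not something the kernel computes:

# Parameters taken from test_mla_kernel_forward_pass above.
kv_lora_rank, qk_rope_head_dim = 512, 32
qk_nope_head_dim, v_head_dim = 64, 64
num_attention_heads = 32

mla_width = kv_lora_rank + qk_rope_head_dim  # 544 cached values per token
# Uncompressed MHA would cache K (nope + rope dims) and V for every head:
mha_width = num_attention_heads * (
    (qk_nope_head_dim + qk_rope_head_dim) + v_head_dim)  # 5120
print(mha_width / mla_width)  # ~9.4x fewer cached values per token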
tests/layers/jax/attention/test_llama4_attention.py
@@ -0,0 +1,135 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import unittest
+ from dataclasses import dataclass, field
+
+ import chex
+
+ os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
+
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from jax.sharding import NamedSharding
+ from jax.sharding import PartitionSpec as P
+
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.common.sharding import build_mesh
+ from tpu_inference.layers.jax.attention.llama4_attention import (
+     L2Norm, Llama4Attention)
+
+
+ @dataclass
+ class SimpleVLLMConfig:
+     additional_config: dict = field(default_factory=dict)
+
+
+ class Llama4AttentionTest(unittest.TestCase):
+     """Unit test suite for Llama4-specific attention components."""
+
+     def setUp(self):
+         devices = jax.devices()[:1]
+         sharding_strategy = {"tensor_parallelism": len(devices)}
+         self.mesh = build_mesh(devices, sharding_strategy)
+
+     def test_l2norm_forward_pass(self):
+         """Tests the forward pass of the L2Norm module with hardcoded values."""
+         eps = 1e-5
+         l2_norm = L2Norm(eps=eps)
+         x = jnp.array([[1.0, 2.0, 3.0, 4.0]], dtype=jnp.float32)
+
+         output = l2_norm(x)
+
+         self.assertEqual(output.shape, x.shape)
+         self.assertEqual(output.dtype, x.dtype)
+
+         # Expected values calculated manually:
+         # mean_sq = (1^2 + 2^2 + 3^2 + 4^2) / 4 = (1+4+9+16)/4 = 30/4 = 7.5
+         # norm_val = sqrt(7.5 + 1e-5)
+         # expected = x / norm_val
+         expected_output = jnp.array([[0.365148, 0.730297, 1.095445, 1.460594]],
+                                     dtype=jnp.float32)
+         self.assertTrue(jnp.allclose(output, expected_output, atol=1e-6))
+
+     def test_l2norm_with_zeros(self):
+         """Tests L2Norm with an all-zero input."""
+         l2_norm = L2Norm(eps=1e-5)
+         x = jnp.zeros((4, 8, 16))
+         output = l2_norm(x)
+         self.assertEqual(output.shape, x.shape)
+         # Output should be all zeros.
+         self.assertTrue(jnp.all(output == 0))
+
+     def test_l2norm_eps_effect(self):
+         """Tests the effect of the epsilon value in L2Norm."""
+         eps = 1e-3
+         l2_norm = L2Norm(eps=eps)
+         x = jax.random.normal(jax.random.PRNGKey(0), (1, 1, 128))
+         output = l2_norm(x)
+
+         mean_sq = jnp.mean(x**2, axis=-1, keepdims=True)
+         expected_output = x * jax.lax.rsqrt(mean_sq + eps)
+
+         self.assertTrue(jnp.allclose(output, expected_output))
+
+     def test_apply_temperature_tuning(self):
+         with jax.set_mesh(self.mesh):
+             hidden_size = 64
+             num_attention_heads = 4
+             head_dim = hidden_size // num_attention_heads
+
+             # Create dummy sharding objects
+             dummy_sharding = NamedSharding(self.mesh, P())
+
+             llama4_attention = Llama4Attention(
+                 hidden_size=hidden_size,
+                 num_attention_heads=num_attention_heads,
+                 num_key_value_heads=num_attention_heads,
+                 head_dim=head_dim,
+                 rope_theta=10000.0,
+                 rope_scaling={},
+                 dtype=jnp.bfloat16,
+                 kv_cache_dtype="auto",
+                 use_qk_norm=False,
+                 temperature_tuning=True,
+                 temperature_tuning_scale=2.0,
+                 temperature_tuning_floor_scale=2.0,
+                 mesh=self.mesh,
+                 random_init=True,
+                 activation_attention_td=dummy_sharding,
+                 activation_attention_out_td=dummy_sharding,
+                 rngs=nnx.Rngs(42),
+             )
+
+             seq_len = 8
+             input_arr_TNH = jnp.ones((seq_len, num_attention_heads, head_dim),
+                                      dtype=jnp.bfloat16)
+             attention_metadata = AttentionMetadata(
+                 input_positions=jnp.arange(seq_len, dtype=jnp.int32))
+             expected_scales = jnp.array(
+                 [1, 2.375, 2.375, 3.20312, 3.20312, 3.76562, 3.76562, 4.21875],
+                 dtype=jnp.bfloat16)
+             output = llama4_attention.apply_temperature_tuning(
+                 attention_metadata, input_arr_TNH)
+             chex.assert_shape(output, (seq_len, num_attention_heads, head_dim))
+
+             expected_output = jnp.ones_like(
+                 input_arr_TNH) * expected_scales[:, None, None]
+             chex.assert_trees_all_close(output, expected_output, atol=1e-3)
+
+
+ if __name__ == "__main__":
+     unittest.main()
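
The hardcoded expected_scales in test_apply_temperature_tuning agree, to within bfloat16 rounding, with the published Llama 4 attention temperature-tuning rule scale(p) = 1 + attn_scale * ln(1 + floor((p + 1) / floor_scale)). A float32 sketch of that rule with the test's constants; treating it as the formula behind this module is an inference from the test values, not taken from tpu_inference source:

import jax.numpy as jnp

positions = jnp.arange(8, dtype=jnp.float32)
attn_scale = 2.0   # temperature_tuning_scale in the test
floor_scale = 2.0  # temperature_tuning_floor_scale in the test

scales = 1.0 + attn_scale * jnp.log(
    jnp.floor((positions + 1.0) / floor_scale) + 1.0)
print(scales)
# ~[1.0, 2.386, 2.386, 3.197, 3.197, 3.773, 3.773, 4.219]
# vs. the test's bfloat16 expectations [1, 2.375, 2.375, 3.20312, ...]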
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.