tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.2rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (250)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +89 -26
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +46 -17
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +44 -17
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/METADATA +7 -9
  240. tpu_inference-0.13.2rc3.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/top_level.txt +0 -0
tests/layers/common/test_attention_interface.py
@@ -0,0 +1,156 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from unittest.mock import MagicMock
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import pytest
+ from jax.sharding import Mesh
+
+ from tpu_inference.layers.common.attention_interface import attention
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.runner.kv_cache import get_kv_cache_shape_with_mesh
+
+ # ---- Test Configuration & Constants ----
+
+ # Total number of tokens across all sequences in the batch
+ TOTAL_TOKENS = 10
+ # Number of sequences in the batch
+ NUM_SEQS = 2
+ # Padded maximum number of sequences
+ MAX_NUM_SEQS = 4
+ # Number of attention heads (Query)
+ NUM_HEADS = 8
+ # Number of attention heads (Key/Value) - for Grouped-Query Attention
+ NUM_KV_HEADS = 4
+ # Total number of blocks in the KV cache
+ NUM_BLOCKS = 32
+ # Number of tokens per block
+ BLOCK_SIZE = 16
+ # Maximum number of blocks a single sequence can occupy
+ MAX_BLOCKS_PER_SEQ = 8
+
+
+ @pytest.fixture
+ def mesh():
+     """Provides a mock 1D JAX mesh for testing."""
+     # Create a mesh with available devices, useful for running on CPU/GPU/TPU.
+     # For this test, it will likely be a single CPU device.
+     devices = np.array(jax.local_devices()[:1])
+     if not devices.any():
+         # Add a mock device if no devices are present (e.g., in a CI environment)
+         devices = np.array([jax.devices("cpu")[0]])
+     return Mesh(devices.reshape((-1, 1, 1)), ("data", "attn_dp", "model"))
+
+
+ # ---- Test for `attention` ----
+
+
+ def _test_attention(monkeypatch, mesh, head_dim, use_sinks=False):
+     """
+     Tests the main `attention` function.
+
+     Verifies that:
+     1. It calls the `sharded_ragged_paged_attention` kernel with correct metadata.
+     2. The final outputs (kv_cache and attention output) have the correct shapes.
+     """
+     # 1. Arrange
+
+     # Create input tensors
+     q_dtype = jnp.float32
+     kv_dtype = jnp.float32
+     q = jnp.ones((TOTAL_TOKENS, NUM_HEADS, head_dim), dtype=q_dtype)
+     k = jnp.ones((TOTAL_TOKENS, NUM_KV_HEADS, head_dim), dtype=kv_dtype)
+     v = jnp.ones((TOTAL_TOKENS, NUM_KV_HEADS, head_dim), dtype=kv_dtype)
+     sinks = jnp.ones((NUM_HEADS, ), dtype=jnp.float32) if use_sinks else None
+
+     kv_cache_shape = get_kv_cache_shape_with_mesh(
+         mesh,
+         NUM_BLOCKS,
+         BLOCK_SIZE,
+         NUM_KV_HEADS,
+         head_dim,
+         kv_dtype,
+     )
+     kv_cache = jnp.zeros(kv_cache_shape, dtype=kv_dtype)
+
+     # Mock ragged_paged_attention to return a tensor of the correct shape
+     mock_paged_attn_kernel = MagicMock(return_value=(jnp.ones(
+         (TOTAL_TOKENS, NUM_HEADS, head_dim)), kv_cache), )
+
+     if head_dim == 64:
+         monkeypatch.setattr(
+             "tpu_inference.layers.common.attention_interface.ragged_paged_attention_hd64",
+             mock_paged_attn_kernel,
+         )
+     else:
+         monkeypatch.setattr(
+             "tpu_inference.layers.common.attention_interface.ragged_paged_attention",
+             mock_paged_attn_kernel,
+         )
+
+     # Create AttentionMetadata
+     attention_metadata = AttentionMetadata(
+         input_positions=jnp.arange(TOTAL_TOKENS, dtype=jnp.int32),
+         block_tables=jnp.zeros((MAX_NUM_SEQS * MAX_BLOCKS_PER_SEQ, ),
+                                dtype=jnp.int32),
+         seq_lens=jnp.array([5, 5, 0, 0], dtype=jnp.int32),
+         query_start_loc=jnp.array([0, 5, 10, 10, 10], dtype=jnp.int32),
+         request_distribution=jnp.array([0, 0, NUM_SEQS], dtype=jnp.int32),
+     )
+
+     # 2. Act
+     final_kv_cache, output = attention(
+         kv_cache=kv_cache,
+         q=q,
+         k=k,
+         v=v,
+         attention_metadata=attention_metadata,
+         mesh=mesh,
+         head_dim_original=head_dim,
+         sinks=sinks,
+     )
+
+     # 3. Assert
+     # Check that the mocked kernel was called exactly once
+     mock_paged_attn_kernel.assert_called_once()
+
+     # Check output shapes
+     assert final_kv_cache.shape == kv_cache.shape
+     assert output.shape == q.shape
+
+     # Check that the output is the one from our mock
+     assert jnp.all(output == 1.0)
+
+
+ def test_attention(monkeypatch, mesh):
+     _test_attention(monkeypatch, mesh, 128)
+
+
+ def test_attention_hd64(monkeypatch, mesh):
+     _test_attention(monkeypatch, mesh, 64)
+
+
+ def test_attention_sink(monkeypatch, mesh):
+     _test_attention(monkeypatch, mesh, 64, True)
+
+
+ def test_attention_sink_no_64_raises_error(monkeypatch, mesh):
+     with pytest.raises(
+             NotImplementedError,
+             match="Attention sink support is only available when head_dim==64"
+     ):
+         _test_attention(monkeypatch, mesh, 128, True)
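Note on the ragged-batch metadata above: with TOTAL_TOKENS = 10, NUM_SEQS = 2, and MAX_NUM_SEQS = 4, the test encodes two live 5-token sequences padded out to a fixed batch of four. A minimal sketch of that padding scheme, assuming only the field semantics implied by the test values (build_ragged_metadata is a hypothetical helper, not a tpu-inference API):

import numpy as np

def build_ragged_metadata(q_lens, max_num_seqs):
    """Pads per-sequence query lengths the way the test metadata does (hypothetical)."""
    num_seqs = len(q_lens)
    # seq_lens is zero-padded up to the max batch size: [5, 5, 0, 0]
    seq_lens = np.zeros(max_num_seqs, dtype=np.int32)
    seq_lens[:num_seqs] = q_lens
    # query_start_loc holds cumulative token offsets, padded by repeating
    # the total token count: [0, 5, 10, 10, 10]
    query_start_loc = np.full(max_num_seqs + 1, sum(q_lens), dtype=np.int32)
    query_start_loc[:num_seqs + 1] = np.concatenate(([0], np.cumsum(q_lens)))
    return seq_lens, query_start_loc

seq_lens, query_start_loc = build_ragged_metadata([5, 5], max_num_seqs=4)
assert seq_lens.tolist() == [5, 5, 0, 0]
assert query_start_loc.tolist() == [0, 5, 10, 10, 10]

The zero-padded seq_lens and the repeated final offset in query_start_loc let a kernel iterate over a fixed-size batch while skipping the empty slots.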
tests/layers/common/test_quantization.py
@@ -0,0 +1,149 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import functools
+
+ import jax
+ import jax.numpy as jnp
+ from absl.testing import absltest, parameterized
+ from jax._src import test_util as jtu
+
+ from tpu_inference.layers.common.quantization import (
+     dequantize_tensor, dequantize_tensor_from_mxfp4_packed, quantize_kv,
+     quantize_tensor, quantize_tensor_to_mxfp4_packed)
+
+
+ @jtu.with_config(jax_numpy_dtype_promotion="standard")
+ class QuantizationTest(jtu.JaxTestCase):
+
+     @parameterized.product(axis=[-1, 0, (0, 1)])
+     def test_mxfp4_quantization(self, axis):
+         if not jtu.is_device_tpu_at_least(version=7):
+             self.skipTest("mxfp4 is only supported in TPUv7+")
+
+         key = jax.random.key(0)
+
+         shape = (128, 128, 128)
+         original = jax.random.normal(key, shape, jnp.bfloat16)
+
+         tensor_q, scale = quantize_tensor_to_mxfp4_packed(original, axis)
+         dequantized = dequantize_tensor_from_mxfp4_packed(
+             tensor_q, scale, axis)
+
+         self.assertAllClose(dequantized, original, rtol=0.5, atol=0.5)
+
+     @parameterized.product(dtype=[jnp.float8_e4m3fn, jnp.int8],
+                            axis=[None, -1, 1, (0, 1)])
+     def test_quantization(self, dtype, axis):
+         key = jax.random.key(0)
+
+         shape = (128, 128, 128)
+         original = jax.random.normal(key, shape, jnp.bfloat16)
+
+         tensor_q, scale = quantize_tensor(dtype, original, axis)
+         dequantized = dequantize_tensor(tensor_q, scale, axis)
+
+         self.assertAllClose(dequantized, original, rtol=0.1, atol=0.1)
+
+     @parameterized.product(dtype=[jnp.float8_e4m3fn, jnp.int8],
+                            axis=[-1, 1],
+                            block_size=[32, 64])
+     def test_block_quantization(self, dtype, axis, block_size):
+         key = jax.random.key(0)
+
+         shape = (128, 128, 128)
+         original = jax.random.normal(key, shape, jnp.bfloat16)
+
+         tensor_q, scale = quantize_tensor(dtype, original, axis, block_size)
+         dequantized = dequantize_tensor(tensor_q, scale, axis)
+
+         self.assertAllClose(dequantized, original, rtol=0.1, atol=0.1)
+
+     @parameterized.product(dtype=[jnp.float8_e4m3fn, jnp.int8],
+                            axis=[(0, 1), (-1, 0)],
+                            block_size=[32, (64, 32)])
+     def test_multi_block_quantization(self, dtype, axis, block_size):
+         key = jax.random.key(0)
+
+         shape = (128, 128, 128)
+         original = jax.random.normal(key, shape, jnp.bfloat16)
+
+         tensor_q, scale = quantize_tensor(dtype, original, axis, block_size)
+         dequantized = dequantize_tensor(tensor_q, scale, axis)
+
+         self.assertAllClose(dequantized, original, rtol=0.1, atol=0.1)
+
+     def test_unaligned_block_quantization_raises_error(self):
+         key = jax.random.key(0)
+
+         shape = (128, 128)
+         tensor = jax.random.normal(key, shape, jnp.bfloat16)
+         block_size = 100
+         axis = 0
+
+         self.assertRaises(
+             ValueError,
+             functools.partial(quantize_tensor, jnp.int8, tensor, axis,
+                               block_size))
+
+     def test_block_quantization_padding(self):
+         key = jax.random.key(0)
+
+         shape = (128, 128)
+
+         original = jax.random.normal(key, shape, jnp.bfloat16)
+         block_size = 100
+         axis = 0
+
+         tensor_q, scale = quantize_tensor(jnp.int8, original, axis, block_size,
+                                           True)
+
+         dequantized = dequantize_tensor(tensor_q, scale, axis)
+
+         padded_size = ((shape[axis] + block_size) // block_size) * block_size
+         self.assertEqual(tensor_q.shape[axis], padded_size)
+         self.assertTrue((tensor_q[shape[0]:] == 0).all())
+         self.assertAllClose(dequantized[:shape[0]],
+                             original,
+                             rtol=0.1,
+                             atol=0.1)
+
+     @parameterized.product(kv_quant_dtype=[jnp.float8_e4m3fn, jnp.int8])
+     def test_quantize_kv(self, kv_quant_dtype):
+         """Tests the quantize_kv function with fp8 and int8 dtypes."""
+         key = jax.random.key(0)
+
+         shape = (128, 128)
+         k_original = jax.random.normal(key, shape, jnp.bfloat16)
+         v_original = jax.random.normal(key, shape, jnp.bfloat16)
+         k_scale = 0.1
+         v_scale = 0.2
+
+         k_quantized, v_quantized = quantize_kv(
+             kv_quant_dtype,
+             k_original,
+             v_original,
+             k_scale,
+             v_scale,
+         )
+
+         k_dequantized = k_quantized.astype(jnp.bfloat16) * k_scale
+         v_dequantized = v_quantized.astype(jnp.bfloat16) * v_scale
+
+         self.assertAllClose(k_dequantized, k_original, rtol=0.2, atol=0.2)
+         self.assertAllClose(v_dequantized, v_original, rtol=0.2, atol=0.2)
+
+
+ if __name__ == "__main__":
+     absltest.main(testLoader=jtu.JaxTestLoader())
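Note: the tests above only exercise the public round trip. For orientation, here is a toy, self-contained sketch of the per-axis symmetric quantization pattern they verify, mirroring the observed (dtype, tensor, axis) -> (tensor_q, scale) call shape; the real quantize_tensor/dequantize_tensor in tpu_inference.layers.common.quantization may differ in details:

import jax
import jax.numpy as jnp

def toy_quantize(tensor, axis=-1):
    # Scale each slice along `axis` so its max magnitude maps to the int8 max.
    max_abs = jnp.max(jnp.abs(tensor), axis=axis, keepdims=True)
    scale = max_abs / 127.0
    tensor_q = jnp.round(tensor / scale).astype(jnp.int8)
    return tensor_q, scale

def toy_dequantize(tensor_q, scale):
    return tensor_q.astype(jnp.bfloat16) * scale

x = jax.random.normal(jax.random.key(0), (128, 128), jnp.bfloat16)
x_q, s = toy_quantize(x)
x_hat = toy_dequantize(x_q, s)
# The round trip is lossy but close, matching the tests' rtol/atol of 0.1.
assert jnp.allclose(x_hat, x, rtol=0.1, atol=0.1)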
tests/layers/jax/__init__.py
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
tests/layers/jax/attention/__init__.py
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
tests/layers/jax/attention/test_common_attention.py
@@ -0,0 +1,103 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import unittest
+ from typing import Tuple
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ from flax import nnx
+ from jax.sharding import Mesh
+ from parameterized import parameterized
+
+ from tpu_inference.layers.common.attention_interface import get_kv_cache_shape
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.jax.attention.attention import Attention
+
+ KVCache = Tuple[jax.Array, jax.Array]
+
+
+ class TestAttention(unittest.TestCase):
+     """Unit test suite for the JAX Attention module."""
+
+     def setUp(self):
+         """Sets up the testing environment before each test."""
+         self.mesh = Mesh(
+             np.array(jax.devices()[:1]).reshape(1, 1, 1, -1),
+             axis_names=(
+                 "data",
+                 "attn_dp",
+                 "expert",
+                 "model",
+             ),
+         )
+
+     @parameterized.expand([["auto"], ["fp8"]])
+     def test_attention_forward_pass(self, kv_cache_str):
+         """Tests the forward pass of the Attention module in prefill mode."""
+         hidden_size = 1024
+         num_attention_heads = 8
+         head_dim = hidden_size // num_attention_heads
+
+         with jax.set_mesh(self.mesh):
+             attention = Attention(hidden_size=hidden_size,
+                                   num_attention_heads=num_attention_heads,
+                                   num_key_value_heads=num_attention_heads,
+                                   head_dim=head_dim,
+                                   rope_theta=10000.0,
+                                   rope_scaling={},
+                                   dtype=jnp.bfloat16,
+                                   mesh=self.mesh,
+                                   random_init=True,
+                                   rngs=nnx.Rngs(42),
+                                   kv_cache_dtype=kv_cache_str)
+
+             seq_len = 64
+             x = jnp.ones((seq_len, hidden_size), dtype=jnp.bfloat16)
+
+             block_size = 16
+             num_blocks = 8
+             kv_dtype = jnp.float8_e4m3fn if kv_cache_str == "fp8" else jnp.bfloat16
+             cache_shape = get_kv_cache_shape(num_blocks, block_size,
+                                              num_attention_heads, head_dim,
+                                              kv_dtype)
+
+             kv_cache = jnp.zeros(cache_shape, dtype=kv_dtype)
+
+             num_required_blocks = seq_len // block_size
+
+             attention_metadata = AttentionMetadata(
+                 input_positions=jnp.arange(seq_len, dtype=jnp.int32),
+                 block_tables=jnp.array(list(range(num_required_blocks)),
+                                        dtype=jnp.int32),
+                 seq_lens=jnp.array([seq_len], dtype=jnp.int32),
+                 query_start_loc=jnp.array([0, seq_len], dtype=jnp.int32),
+                 request_distribution=jnp.array([0, 0, 1], dtype=jnp.int32),
+             )
+
+             new_kv_cache, output = attention(
+                 x,
+                 is_prefill=True,
+                 kv_cache=kv_cache,
+                 attention_metadata=attention_metadata,
+             )
+
+             self.assertEqual(output.shape, (seq_len, hidden_size))
+
+             self.assertEqual(new_kv_cache.shape, kv_cache.shape)
+
+
+ if __name__ == "__main__":
+     unittest.main()
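Note: the block-table arithmetic behind this prefill test is plain paged-KV indexing. A small sketch, assuming the usual paged-attention convention that token p lives at slot p % block_size of physical block block_tables[p // block_size]:

# A 64-token prefill with 16-token blocks needs exactly 4 physical blocks,
# which the test assigns contiguously as [0, 1, 2, 3].
seq_len, block_size = 64, 16
num_required_blocks = seq_len // block_size      # -> 4
block_tables = list(range(num_required_blocks))  # -> [0, 1, 2, 3]
p = 37  # an arbitrary token position
assert block_tables[p // block_size] == 2 and p % block_size == 5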
tests/layers/jax/attention/test_deepseek_v3_attention.py
@@ -0,0 +1,233 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import unittest
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ from flax import nnx
+ from jax.sharding import Mesh, PartitionSpec
+ from parameterized import parameterized
+
+ import tpu_inference.kernels.mla.v1.kernel as mla
+ from tpu_inference.layers.common.attention_interface import get_kv_cache_shape
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.common.sharding import ShardingAxisName
+ from tpu_inference.layers.jax.attention.deepseek_v3_attention import MLA
+
+
+ class TestMLA(unittest.TestCase):
+
+     def setUp(self):
+         os.environ["NEW_MODEL_DESIGN"] = "1"
+         self.mesh = Mesh(
+             np.array(jax.devices("tpu")[:1]).reshape(1, 1, 1, 1),
+             axis_names=("data", "attn_dp", "expert", "model"),
+         )
+
+     @parameterized.expand([["auto"], ["fp8"]])
+     def test_mla_forward_pass(self, kv_cache_str):
+         hidden_size = 256
+
+         num_key_value_heads = 32
+         qk_nope_head_dim = 64
+         qk_rope_head_dim = 32
+
+         with jax.set_mesh(self.mesh):
+             query_tnh_spec = PartitionSpec(None, ShardingAxisName.MLP_TENSOR,
+                                            None)
+             keyvalue_skh_spec = PartitionSpec(None,
+                                               ShardingAxisName.MLP_TENSOR,
+                                               None)
+             attn_o_tnh_spec = PartitionSpec(None, ShardingAxisName.MLP_TENSOR,
+                                             None)
+
+             mla_layer = MLA(
+                 hidden_size=hidden_size,
+                 num_attention_heads=32,
+                 num_key_value_heads=num_key_value_heads,
+                 head_dim=64,  # MLA uses v_head_dim as head_dim
+                 rope_theta=10000,
+                 dtype=jnp.bfloat16,
+                 q_lora_rank=512,
+                 kv_lora_rank=512,
+                 qk_nope_head_dim=qk_nope_head_dim,  # Half of DeepSeek v3's real values
+                 qk_rope_head_dim=qk_rope_head_dim,  # Half of DeepSeek v3's real values
+                 v_head_dim=64,  # Half of DeepSeek v3's real values
+                 rms_norm_eps=1e-5,
+                 rngs=nnx.Rngs(42),
+                 rope_scaling={
+                     "beta_fast": 32,
+                     "beta_slow": 1,
+                     "factor": 40,
+                     "mscale": 1.0,
+                     "mscale_all_dim": 1.0,
+                     "original_max_position_embeddings": 4096,
+                     "type": "yarn",
+                 },
+                 mesh=self.mesh,
+                 random_init=True,
+                 kv_cache_dtype=kv_cache_str,
+                 query_tnh=query_tnh_spec,
+                 keyvalue_skh=keyvalue_skh_spec,
+                 attn_o_tnh=attn_o_tnh_spec,
+                 q_da_sharding=(None, ShardingAxisName.VOCAB),
+                 anh_sharding=(None, ShardingAxisName.MLP_TENSOR, None),
+                 ap_sharding=(None, ShardingAxisName.MLP_TENSOR),
+                 kv_da_sharding=(None, ShardingAxisName.VOCAB),
+                 rd_sharding=(ShardingAxisName.MLP_TENSOR, None),
+             )
+
+             # Create input tensor
+             seq_len = 32
+             x = jnp.ones((seq_len, hidden_size), dtype=jnp.bfloat16)
+
+             # Create KV cache
+             qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+             block_size = 16
+             num_blocks = 8
+             kv_dtype = jnp.float8_e4m3fn if kv_cache_str == "fp8" else jnp.bfloat16
+             cache_shape = get_kv_cache_shape(num_blocks, block_size,
+                                              num_key_value_heads, qk_head_dim,
+                                              kv_dtype)
+             kv_cache = jnp.zeros(cache_shape, dtype=kv_dtype)
+
+             # Create attention metadata
+             attention_metadata = AttentionMetadata(
+                 input_positions=jnp.arange(seq_len, dtype=jnp.int32),
+                 block_tables=jnp.zeros((8, ), dtype=jnp.int32),
+                 seq_lens=jnp.ones((1, ), dtype=jnp.int32) * seq_len,
+                 query_start_loc=jnp.array(
+                     [0, seq_len], dtype=jnp.int32),  # This is cu_q_lens
+                 request_distribution=jnp.array([0, 0, 1], dtype=jnp.int32),
+             )
+
+             mla_layer.rope.initialize_cache(self.mesh)
+
+             # Run forward pass
+             new_kv_cache, output = mla_layer(
+                 x,
+                 is_prefill=True,
+                 kv_cache=kv_cache,
+                 attention_metadata=attention_metadata)
+
+             # Verify output shapes
+             self.assertEqual(output.shape, (seq_len, hidden_size))
+             self.assertEqual(new_kv_cache.shape, kv_cache.shape)
+
+     @parameterized.expand([["auto"]])  # MLA kernel does not support fp8 yet
+     def test_mla_kernel_forward_pass(self, kv_cache_str):
+         hidden_size = 256
+
+         num_key_value_heads = 1
+         qk_nope_head_dim = 64
+         qk_rope_head_dim = 32
+         v_head_dim = 64
+         kv_lora_rank = 512
+
+         with jax.set_mesh(self.mesh):
+             query_tnh_spec = PartitionSpec(ShardingAxisName.MLP_TENSOR, None,
+                                            None)
+             keyvalue_skh_spec = PartitionSpec(ShardingAxisName.MLP_TENSOR,
+                                               None)
+             attn_o_tnh_spec = PartitionSpec(ShardingAxisName.MLP_TENSOR, None,
+                                             None)
+
+             mla_layer = MLA(
+                 hidden_size=hidden_size,
+                 num_attention_heads=32,
+                 num_key_value_heads=num_key_value_heads,
+                 head_dim=v_head_dim,  # MLA uses v_head_dim as head_dim
+                 rope_theta=10000,
+                 dtype=jnp.bfloat16,
+                 q_lora_rank=512,
+                 kv_lora_rank=kv_lora_rank,
+                 qk_nope_head_dim=qk_nope_head_dim,
+                 qk_rope_head_dim=qk_rope_head_dim,
+                 v_head_dim=v_head_dim,
+                 rms_norm_eps=1e-5,
+                 rngs=nnx.Rngs(42),
+                 rope_scaling={
+                     "beta_fast": 32,
+                     "beta_slow": 1,
+                     "factor": 40,
+                     "mscale": 1.0,
+                     "mscale_all_dim": 1.0,
+                     "original_max_position_embeddings": 4096,
+                     "type": "yarn",
+                 },
+                 mesh=self.mesh,
+                 random_init=True,
+                 kv_cache_dtype=kv_cache_str,
+                 use_mla_kernel=True,  # Set to True in order to trigger the MLA kernel.
+                 query_tnh=query_tnh_spec,
+                 keyvalue_skh=keyvalue_skh_spec,
+                 attn_o_tnh=attn_o_tnh_spec,
+                 q_da_sharding=(None, ShardingAxisName.VOCAB),
+                 anh_sharding=(None, ShardingAxisName.MLP_TENSOR, None),
+                 ap_sharding=(None, ShardingAxisName.MLP_TENSOR),
+                 kv_da_sharding=(None, ShardingAxisName.VOCAB),
+                 rd_sharding=(ShardingAxisName.MLP_TENSOR, None),
+             )
+
+             # Create input tensor
+             seq_len = 32
+             x = jnp.ones((seq_len, hidden_size), dtype=jnp.bfloat16)
+
+             # Create KV cache for the MLA kernel
+             block_size = 16
+             num_blocks = 8
+             kv_dtype = jnp.float8_e4m3fn if kv_cache_str == "fp8" else jnp.bfloat16
+
+             # For the MLA kernel, the cache head dimension is the latent width:
+             # kv_lora_rank + qk_rope_head_dim
+             cache_shape = mla.get_kv_cache_shape(
+                 num_blocks, block_size, kv_lora_rank + qk_rope_head_dim,
+                 kv_dtype)
+             kv_cache = jnp.zeros(cache_shape, dtype=kv_dtype)
+
+             # Create attention metadata
+             attention_metadata = AttentionMetadata(
+                 input_positions=jnp.arange(seq_len, dtype=jnp.int32),
+                 block_tables=jnp.zeros((8, ), dtype=jnp.int32),
+                 seq_lens=jnp.ones((1, ), dtype=jnp.int32) * seq_len,
+                 query_start_loc=jnp.array([0, seq_len], dtype=jnp.int32),
+                 request_distribution=jnp.array([0, 0, 1], dtype=jnp.int32),
+             )
+
+             mla_layer.rope.initialize_cache(self.mesh)
+
+             # Run forward pass
+             new_kv_cache, output = mla_layer(
+                 x,
+                 is_prefill=True,
+                 kv_cache=kv_cache,
+                 attention_metadata=attention_metadata)
+
+             # Verify output shapes
+             self.assertEqual(output.shape, (seq_len, hidden_size))
+             self.assertEqual(new_kv_cache.shape, kv_cache.shape)
+
+
+ if __name__ == "__main__":
+     unittest.main()
+
+
+ def tearDownModule():
+     del os.environ["NEW_MODEL_DESIGN"]
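Note: the width passed to mla.get_kv_cache_shape, kv_lora_rank + qk_rope_head_dim, reflects MLA's cache compression: each token caches one shared latent vector plus the RoPE key slice rather than full per-head K/V. A back-of-the-envelope comparison using this file's constants (a reading of the tests, and assuming the non-MLA cache stores both K and V for every head):

kv_lora_rank, qk_rope_head_dim = 512, 32
num_kv_heads, qk_head_dim = 32, 64 + 32           # from test_mla_forward_pass
mla_width = kv_lora_rank + qk_rope_head_dim       # 544 entries per token
full_kv_width = 2 * num_kv_heads * qk_head_dim    # 6144 entries per token
assert (mla_width, full_kv_width) == (544, 6144)  # roughly an 11x reduction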