tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Files changed (250)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +45 -15
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +41 -16
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  240. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import math
  from dataclasses import InitVar, dataclass
  from typing import Any, Tuple
@@ -6,14 +20,18 @@ import jax
  import jax.numpy as jnp
  from flax import nnx
  from flax.typing import Sharding
- from jax.experimental import shard_map
  from jax.sharding import Mesh
  from jax.sharding import PartitionSpec as P

  from tpu_inference import utils
+ from tpu_inference.kernels.mla.v1.kernel import mla_ragged_paged_attention
  from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
  ragged_paged_attention
+ from tpu_inference.kernels.ragged_paged_attention.v3.tuned_block_sizes import \
+ get_tuned_block_sizes
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.common.quantization import quantize_kv
+ from tpu_inference.layers.common.sharding import ShardingAxisName
  from tpu_inference.layers.jax.base import create_param
  from tpu_inference.layers.jax.layers import RMSNorm
  from tpu_inference.layers.jax.rope import DeepseekScalingRotaryEmbedding
@@ -48,8 +66,9 @@ class MLA(nnx.Module):
  rms_norm_eps: float

  # Sharding attributes
- nhd_sharding: Sharding = ()
+ rd_sharding: Sharding = ()
  q_da_sharding: Sharding = ()
+ ap_sharding: Sharding = ()
  anh_sharding: Sharding = ()
  kv_da_sharding: Sharding = ()

@@ -66,6 +85,7 @@ class MLA(nnx.Module):
  rope_input_ordering: str = "split"
  quant: Any | None = None
  rope_mscale_all_dim: float = 1.0
+ use_mla_kernel: bool = False

  rngs: InitVar[nnx.Rngs]

@@ -77,10 +97,10 @@
  self.N = self.num_attention_heads
  self.K = self.num_key_value_heads
  self.D = self.hidden_size
-
  self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim

- assert self.N == self.K, "N and K must be equal for MLA"
+ if not self.use_mla_kernel:
+ assert self.N == self.K, "N and K must be equal for MLA"

  if self.rope_scaling["factor"] <= 1.0:
  yarn_mscale = 1.0
@@ -108,10 +128,10 @@
  self.q_da_sharding,
  self.dtype,
  random_init=self.random_init)
- self.kernel_q_up_proj_ANH = create_param(
+ self.kernel_q_up_proj_AP = create_param(
  rngs,
- (self.q_lora_rank, self.N, self.qk_head_dim),
- self.anh_sharding,
+ (self.q_lora_rank, self.N * self.qk_head_dim),
+ self.ap_sharding,
  self.dtype,
  random_init=self.random_init,
  )
@@ -122,17 +142,38 @@
  self.dtype,
  random_init=self.random_init,
  )
- self.kernel_kv_up_proj_ANH = create_param(
- rngs,
- (self.kv_lora_rank, self.N,
- self.qk_nope_head_dim + self.v_head_dim),
- self.anh_sharding,
- self.dtype,
- random_init=self.random_init,
- )
- self.kernel_o_proj_NHD = create_param(
- rngs, (self.N, self.v_head_dim, self.D),
- self.nhd_sharding,
+ # NOTE (jacobplatin): we are keeping these variables as 3D because
+ # we would need to reshape them before the below projection,
+ # which caused issues as Qwix wasn't quantizing it correctly
+ # on the abstract pass
+ if self.use_mla_kernel:
+ self.kernel_k_up_proj_ANH = create_param(
+ rngs,
+ (self.kv_lora_rank, self.N, self.qk_nope_head_dim),
+ self.anh_sharding,
+ self.dtype,
+ random_init=self.random_init,
+ )
+ self.kernel_v_up_proj_ANH = create_param(
+ rngs,
+ (self.kv_lora_rank, self.N, self.v_head_dim),
+ self.anh_sharding,
+ self.dtype,
+ random_init=self.random_init,
+ )
+ else:
+ self.kernel_kv_up_proj_AL = create_param(
+ rngs,
+ (self.kv_lora_rank, self.N *
+ (self.qk_nope_head_dim + self.v_head_dim)),
+ self.
+ ap_sharding, # NOTE: we use the same sharding for kv_up_proj_AL and kernel_q_up_proj_AP
+ self.dtype,
+ random_init=self.random_init,
+ )
+ self.kernel_o_proj_RD = create_param(
+ rngs, (self.N * self.v_head_dim, self.D),
+ self.rd_sharding,
  self.dtype,
  random_init=self.random_init)
  self.q_rms_norm = RMSNorm(
@@ -188,17 +229,24 @@
  q_TA = jnp.einsum("TD,DA -> TA", x_q_TD,
  self.kernel_q_down_proj_DA.value)
  q_TA = self.q_rms_norm(q_TA)
- # Query up projection.
- q_TNH = jnp.einsum("TA,ANH -> TNH", q_TA,
- self.kernel_q_up_proj_ANH.value)
+ # Query up projection, then reshape to TNH.
+ q_TP = jnp.einsum("TA,AP -> TP", q_TA,
+ self.kernel_q_up_proj_AP.value)
+ q_TNH = q_TP.reshape(q_TA.shape[0], self.N, self.qk_head_dim)
  # Split the query into nope and rope.
  q_nope_TNH = q_TNH[..., :self.qk_nope_head_dim]
  q_rope_TNH = q_TNH[..., self.qk_nope_head_dim:]
  q_rope_TNH = self.rope.apply_rope(md.input_positions, q_rope_TNH)
- # Concatenate the nope and rope queries.
- q_TNH = jnp.concatenate([q_nope_TNH, q_rope_TNH], axis=-1)
- # Multiple the query by scaling factor
- q_TNH = nnx.with_sharding_constraint(q_TNH, self.query_tnh)
+ if self.use_mla_kernel:
+ # Absorb the k up-projection matrix into q
+ q_TNA = jnp.einsum("TNH,ANH -> TNA", q_nope_TNH,
+ self.kernel_k_up_proj_ANH.value)
+ q_TNA = nnx.with_sharding_constraint(q_TNA, self.query_tnh)
+ else:
+ # Concatenate the nope and rope queries.
+ q_TNH = jnp.concatenate([q_nope_TNH, q_rope_TNH], axis=-1)
+ # Multiply the query by scaling factor
+ q_TNH = nnx.with_sharding_constraint(q_TNH, self.query_tnh)

  with jax.named_scope("kv_proj"):
  # KV down projection.
@@ -209,21 +257,30 @@
  # Reshape k_rope_BSH to include head dimension for RoPE application
  k_rope_SNH = k_rope_SH[..., None, :]
  k_rope_SNH = self.rope.apply_rope(md.input_positions, k_rope_SNH)
- k_rope_SNH = jnp.broadcast_to(
- k_rope_SNH,
- (k_rope_SNH.shape[0], self.N, self.qk_rope_head_dim))
+ assert k_rope_SNH.shape[1] == 1
+ k_rope_SH = k_rope_SNH[:, 0, :]
+
  kv_SA = kv_SA[..., :self.kv_lora_rank]
  kv_SA = self.kv_rms_norm(kv_SA)
- # KV up projection.
- kv_nope_SNH = jnp.einsum("SA,ANH -> SNH", kv_SA,
- self.kernel_kv_up_proj_ANH.value)
- # Split the latent kv vector into k nope vector and v vector.
- k_nope_SNH = kv_nope_SNH[..., :self.qk_nope_head_dim]
- v_SNH = kv_nope_SNH[..., self.qk_nope_head_dim:]
- # Concatenate the key vector.
- k_SNH = jnp.concatenate([k_nope_SNH, k_rope_SNH], axis=-1)
- k_SNH = nnx.with_sharding_constraint(k_SNH, self.keyvalue_skh)
- v_SNH = nnx.with_sharding_constraint(v_SNH, self.keyvalue_skh)
+ kv_SA = nnx.with_sharding_constraint(kv_SA, self.keyvalue_skh)
+
+ if not self.use_mla_kernel:
+ k_rope_SNH = jnp.broadcast_to(
+ k_rope_SNH,
+ (k_rope_SNH.shape[0], self.N, self.qk_rope_head_dim))
+ # KV up projection, then reshape to SN(Hk+Hv).
+ kv_SL = jnp.einsum("SA,AL -> SL", kv_SA,
+ self.kernel_kv_up_proj_AL.value)
+ kv_nope_SNH = kv_SL.reshape(
+ kv_SA.shape[0], self.N,
+ self.qk_nope_head_dim + self.v_head_dim)
+ # Split the latent kv vector into k nope vector and v vector.
+ k_nope_SNH = kv_nope_SNH[..., :self.qk_nope_head_dim]
+ v_SNH = kv_nope_SNH[..., self.qk_nope_head_dim:]
+ # Concatenate the key vector.
+ k_SNH = jnp.concatenate([k_nope_SNH, k_rope_SNH], axis=-1)
+ k_SNH = nnx.with_sharding_constraint(k_SNH, self.keyvalue_skh)
+ v_SNH = nnx.with_sharding_constraint(v_SNH, self.keyvalue_skh)

  with jax.named_scope("attn_op"):
  # TODO(wenxindongwork): K and V have different head dimension,
@@ -234,44 +291,67 @@
  # q, k, v head dimension to be multiple of 128. For now, we will
  # pad the q, k, v dimension to multiple of 128.
  # We should update the MLA kv cache implementation in the future.
- multiple_of_128 = ((self.qk_head_dim - 1) // 128 + 1) * 128
- q_TNH = jnp.pad(q_TNH, ((0, 0), (0, 0),
- (0, multiple_of_128 - self.qk_head_dim)))
- k_SNH = jnp.pad(k_SNH, ((0, 0), (0, 0),
- (0, multiple_of_128 - self.qk_head_dim)))
- v_SNH = jnp.pad(v_SNH, ((0, 0), (0, 0),
- (0, multiple_of_128 - self.v_head_dim)))
+ if not self.use_mla_kernel: # MLA kernel handles padding
+ multiple_of_128 = ((self.qk_head_dim - 1) // 128 + 1) * 128
+ q_TNH = jnp.pad(q_TNH,
+ ((0, 0), (0, 0),
+ (0, multiple_of_128 - self.qk_head_dim)))
+ k_SNH = jnp.pad(k_SNH,
+ ((0, 0), (0, 0),
+ (0, multiple_of_128 - self.qk_head_dim)))
+ v_SNH = jnp.pad(v_SNH,
+ ((0, 0), (0, 0),
+ (0, multiple_of_128 - self.v_head_dim)))
+
  q_scale = k_scale = v_scale = None
- if self.kv_cache_quantized_dtype:
- # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
- # q_scale = self._q_scale
- k_scale = self._k_scale
- v_scale = self._v_scale
- k_SNH, v_SNH = utils.quantize_kv(k_SNH, v_SNH,
- self.kv_cache_quantized_dtype,
- k_scale, v_scale)
- new_kv_cache, outputs_TNH = self.attention(
- is_prefill,
- kv_cache,
- q_TNH,
- k_SNH,
- v_SNH,
- attention_metadata,
- self.mesh,
- q_scale,
- k_scale,
- v_scale,
- )
- # TODO(wenxindongwork): For now, unpad the outputs_TNH to match the v_head_dim.
- # We shall add the MLA kv cache implementation in the future.
- outputs_TNH = outputs_TNH[..., :self.v_head_dim]

- with jax.named_scope("o_proj"):
- o_TD = jnp.einsum("TNH,NHD -> TD", outputs_TNH,
- self.kernel_o_proj_NHD.value)
- o_TD = nnx.with_sharding_constraint(
- o_TD, self.activation_attention_out_td)
- return new_kv_cache, o_TD
+ # TODO(gpolovets): MLA does not currently support quantized KV!
+ if not self.use_mla_kernel:
+ if self.kv_cache_quantized_dtype:
+ # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
+ k_scale = self._k_scale
+ v_scale = self._v_scale
+ k_SNH, v_SNH = quantize_kv(self.kv_cache_quantized_dtype,
+ k_SNH, v_SNH, k_scale, v_scale)
+
+ new_kv_cache, outputs_TNH = self.attention(
+ is_prefill,
+ kv_cache,
+ q_TNH,
+ k_SNH,
+ v_SNH,
+ attention_metadata,
+ self.mesh,
+ q_scale,
+ k_scale,
+ v_scale,
+ )
+ # TODO(wenxindongwork): For now, unpad the outputs_TNH to match the v_head_dim.
+ # We shall add the MLA kv cache implementation in the future.
+ outputs_TNH = outputs_TNH[..., :self.v_head_dim]
+
+ else:
+ new_kv_cache, outputs_TNA = self.mla_attention(
+ kv_cache,
+ q_TNA,
+ q_rope_TNH,
+ kv_SA,
+ k_rope_SH,
+ attention_metadata,
+ self.mesh,
+ )
+ outputs_TNH = jnp.einsum("TNA,ANH -> TNH", outputs_TNA,
+ self.kernel_v_up_proj_ANH.value)
+
+ with jax.named_scope("o_proj"):
+ outputs_TNH = nnx.with_sharding_constraint(
+ outputs_TNH, self.activation_attention_out_td)
+ outputs_TR = outputs_TNH.reshape(outputs_TNH.shape[0],
+ self.N * self.v_head_dim)
+ o_TD = jnp.einsum("TR,RD -> TD", outputs_TR,
+ self.kernel_o_proj_RD.value)
+
+ return new_kv_cache, o_TD

  def attention(
  self,
@@ -326,21 +406,22 @@
  out_specs = (self.attn_o_tnh, P(None, None, "model"))

  def _ragged_paged_attention(*args):
- return ragged_paged_attention(
+ outputs = ragged_paged_attention(
  *args,
  sm_scale=self.scale,
  q_scale=q_scale,
  k_scale=k_scale,
  v_scale=v_scale,
  )
+ return outputs

  output_TNH, kv_cache = jax.jit(
- shard_map.shard_map(
+ jax.shard_map(
  _ragged_paged_attention,
  mesh=mesh,
  in_specs=in_specs,
  out_specs=out_specs,
- check_rep=False,
+ check_vma=False,
  ))(
  q_TNH,
  k_SKH,
@@ -352,3 +433,115 @@
  md.request_distribution,
  )
  return kv_cache, output_TNH
+
+ def mla_attention(
+ self,
+ kv_cache: KVCache,
+ q_TNA: jax.Array,
+ q_rope_TNH: jax.Array,
+ k_SA: jax.Array,
+ k_rope_SH: jax.Array,
+ attention_metadata: AttentionMetadata,
+ mesh: Mesh,
+ ) -> Tuple[KVCache, jax.Array]:
+ """Performs scaled dot-product attention and updates the KV cache.
+
+ This function handles the core attention logic, which varies between
+ prefill and generation modes. In prefill, it computes self-attention
+ over the input sequence with a causal mask. In generation, it attends
+ to the full history of keys and values stored in the cache.
+
+ Args:
+ kv_cache: The key-value cache to be updated and used.
+ q_TNA: Query tensor of shape `(query_seq, num_attention_heads, lkv_dim)`.
+ q_rope_TNH: Query rope tensor of shape `(query_seq, num_attention_heads, rope_dim)`.
+ k_SA: Key tensor of shape `(kv_seq, lkv_dim)`.
+ k_rope_SH: Key rope tensor of shape `(kv_seq, rope_dim)`.
+ attention_metadata: Metadata containing sequence lengths.
+ mesh: The JAX device mesh (unused in this specific function but
+ kept for potential future use or API consistency).
+ q_scale: Quantization scale for q.
+ k_scale: Quantization scale for k.
+ v_scale: Quantization scale for v.
+
+ Returns:
+ A tuple containing:
+ - The updated KV cache.
+ - The attention output tensor of shape
+ `(seq, num_q_heads, head_dim)`.
+ """
+ md = attention_metadata
+ in_specs = (
+ self.query_tnh, # q
+ self.query_tnh, # q_rope
+ self.keyvalue_skh, # k
+ self.keyvalue_skh, # k_rope
+ P(ShardingAxisName.MLP_TENSOR), # kv_cache
+ P(ShardingAxisName.ATTN_DATA), # md.seq_lens: Replicated
+ P(ShardingAxisName.ATTN_DATA), # page_indices_flat: Replicated
+ P(ShardingAxisName.ATTN_DATA), # query_start_loc: Replicated
+ P(ShardingAxisName.ATTN_DATA), # distribution: Replicated
+ )
+
+ out_specs = (self.attn_o_tnh, P(ShardingAxisName.MLP_TENSOR))
+
+ def _mla_ragged_paged_attention(q, q_rope, k, k_rope, kv_cache, *args):
+
+ def _initialize_block_sizes():
+ # Set reasonable starting estimates for block sizes. (TODO(gpolovets): update this to use tuned sizes)
+ # Referring to get_tuned_block_sizes() in kernels/ragged_paged_attention/v3/tuned_block_sizes.py: 'TPU v7'/128/'q_bfloat16_kv_bfloat16/q_head-128_kv_head-1_head-128'/4096
+ max_num_tokens = q.shape[0]
+ max_num_seqs = md.seq_lens.shape[0]
+ num_page_indices = md.block_tables.shape[0]
+ assert num_page_indices % max_num_seqs == 0
+ pages_per_seq = num_page_indices // max_num_seqs
+ # num_kv_pages_per_block = min(pages_per_seq, 16)
+ bkv_p, bq_sz = get_tuned_block_sizes(
+ q.dtype,
+ kv_cache.dtype,
+ self.num_attention_heads,
+ 1,
+ self.qk_nope_head_dim,
+ kv_cache.shape[1], # page size
+ max_num_tokens,
+ pages_per_seq,
+ )
+ num_kv_pages_per_block = min(min(pages_per_seq, bkv_p), 4)
+ num_queries_per_block = min(min(max_num_tokens, bq_sz),
+ 4) # OOMS at 8
+ return num_kv_pages_per_block, num_queries_per_block
+
+ num_kv_pages_per_block, num_queries_per_block = _initialize_block_sizes(
+ )
+ output, kv_cache = mla_ragged_paged_attention(
+ q,
+ q_rope,
+ k,
+ k_rope,
+ kv_cache,
+ *args,
+ sm_scale=self.scale,
+ num_kv_pages_per_block=num_kv_pages_per_block,
+ num_queries_per_block=num_queries_per_block)
+
+ return kv_cache, output
+
+ kv_cache, output_TNH = jax.jit(
+ jax.shard_map(
+ _mla_ragged_paged_attention,
+ mesh=mesh,
+ in_specs=in_specs,
+ out_specs=out_specs,
+ check_vma=False,
+ ), )(
+ q_TNA,
+ q_rope_TNH,
+ k_SA,
+ k_rope_SH,
+ kv_cache,
+ md.seq_lens,
+ md.block_tables,
+ md.query_start_loc,
+ md.request_distribution,
+ )
+ return kv_cache, output_TNH
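
The hunks above modify the MLA (multi-head latent attention) module: when `use_mla_kernel` is enabled, the combined kv up-projection is split into `kernel_k_up_proj_ANH` and `kernel_v_up_proj_ANH`, the k up-projection is absorbed into the query, and attention runs directly against the compressed latent KV via `mla_ragged_paged_attention`. The sketch below illustrates only the weight-absorption algebra with toy shapes and made-up array names; it is not the module's API, and it omits RoPE, softmax scaling, masking, and the paged KV cache that the real kernel handles.

```python
import jax
import jax.numpy as jnp

# Toy sizes (illustrative only): T query tokens, S cached tokens, N heads,
# H "nope" head dim, A latent (kv_lora_rank) dim, V value head dim.
T, S, N, H, A, V = 4, 6, 8, 128, 512, 128
key = jax.random.PRNGKey(0)
q_nope_TNH = jax.random.normal(key, (T, N, H))
k_up_ANH = jax.random.normal(key, (A, N, H))   # stands in for kernel_k_up_proj_ANH
v_up_ANV = jax.random.normal(key, (A, N, V))   # stands in for kernel_v_up_proj_ANH
kv_latent_SA = jax.random.normal(key, (S, A))  # compressed (latent) KV stream

# Absorb the k up-projection into the query: scores are then computed
# against the latent KV directly, so per-head keys are never materialized.
q_TNA = jnp.einsum("TNH,ANH->TNA", q_nope_TNH, k_up_ANH)
scores_TNS = jnp.einsum("TNA,SA->TNS", q_TNA, kv_latent_SA)
weights_TNS = jax.nn.softmax(scores_TNS, axis=-1)  # scaling and causal mask omitted

# The attention output stays in the latent space; the v up-projection is
# applied once at the end, mirroring the "TNA,ANH -> TNH" einsum in the diff.
out_TNA = jnp.einsum("TNS,SA->TNA", weights_TNS, kv_latent_SA)
out_TNV = jnp.einsum("TNA,ANV->TNV", out_TNA, v_up_ANV)
print(out_TNV.shape)  # (4, 8, 128)
```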
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from dataclasses import InitVar, dataclass
  from typing import Tuple

@@ -5,7 +19,6 @@ import jax
  import jax.numpy as jnp
  from flax import nnx
  from flax.typing import Sharding
- from jax.experimental import shard_map
  from jax.sharding import Mesh
  from jax.sharding import PartitionSpec as P

@@ -13,6 +26,7 @@ from tpu_inference import utils
  from tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 import \
  ragged_paged_attention_hd64
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.common.quantization import quantize_kv
  from tpu_inference.layers.jax.base import create_param
  from tpu_inference.layers.jax.rope import GptOssRotaryEmbedding

@@ -158,17 +172,17 @@
  ) -> Tuple[KVCache, jax.Array]:
  """Performs scaled dot-product attention by calling the ragged_paged_attention kernel."""
  md = attention_metadata
- kv_cache_spec = P(None, None, "model")
+ kv_cache_spec = P("data", None, "model")

  in_specs = (
  self.query_tnh, # q
  self.keyvalue_skh, # k
  self.keyvalue_skh, # v
  kv_cache_spec, # kv_cache
- P(), # md.seq_lens: Replicated
- P(), # page_indices_flat: Replicated
- P(), # query_start_loc: Replicated
- P(), # distribution: Replicated
+ P("data"), # md.seq_lens
+ P("data"), # page_indices_flat
+ P("data"), # query_start_loc
+ P("data"), # distribution
  P(('model')), # sinks
  )
  out_specs = (self.attn_o_tnh, kv_cache_spec)
@@ -185,12 +199,12 @@
  )

  output_TNH, kv_cache = jax.jit(
- shard_map.shard_map(
+ jax.shard_map(
  _ragged_paged_attention_wrapper,
  mesh=mesh,
  in_specs=in_specs,
  out_specs=out_specs,
- check_rep=False,
+ check_vma=False,
  ))(
  q_TNH,
  k_SKH,
@@ -235,9 +249,8 @@
  # q_scale = self._q_scale
  k_scale = self._k_scale
  v_scale = self._v_scale
- k_TKH, v_TKH = utils.quantize_kv(k_TKH, v_TKH,
- self.kv_cache_quantized_dtype,
- k_scale, v_scale)
+ k_TKH, v_TKH = quantize_kv(self.kv_cache_quantized_dtype, k_TKH,
+ v_TKH, k_scale, v_scale)

  with jax.named_scope("attn_op"):
  new_kv_cache, attn_out_TNH = self.attention(
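
Both this file and the MLA module above migrate from `jax.experimental.shard_map.shard_map(..., check_rep=False)` to the top-level `jax.shard_map(..., check_vma=False)`. Below is a minimal usage sketch of the new entry point, assuming a JAX release that already exposes `jax.shard_map` with the `check_vma` argument (as these call sites require); the mesh and function are toy stand-ins, not the package's code.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

# 1-D mesh over whatever devices are present (a single CPU device also works).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("model",))

def double_local(x):
    # Runs once per shard: x is this device's block of the global array.
    return x * 2

x = jnp.arange(4.0 * devices.size)  # length divisible by the device count
y = jax.jit(
    jax.shard_map(
        double_local,
        mesh=mesh,
        in_specs=P("model"),
        out_specs=P("model"),
        check_vma=False,  # replaces check_rep=False from the experimental API
    ))(x)
print(y)
```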
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from dataclasses import dataclass

  import jax
@@ -5,8 +19,8 @@ import jax.numpy as jnp
  from flax import nnx
  from jax.sharding import Sharding

- from tpu_inference import utils
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.common.quantization import quantize_kv
  from tpu_inference.layers.jax.attention.attention import Attention, KVCache
  from tpu_inference.layers.jax.rope_interface import apply_rope
  from tpu_inference.logger import init_logger
@@ -114,9 +128,8 @@
  # q_scale = self._q_scale
  k_scale = self._k_scale
  v_scale = self._v_scale
- k_SKH, v_SKH = utils.quantize_kv(k_SKH, v_SKH,
- self.kv_cache_quantized_dtype,
- k_scale, v_scale)
+ k_SKH, v_SKH = quantize_kv(self.kv_cache_quantized_dtype, k_SKH,
+ v_SKH, k_scale, v_scale)

  with jax.named_scope("attn_op"):
  new_kv_cache, outputs_TNH = self.attention(
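
This hunk, like the GPT-OSS and MLA hunks above, moves the call sites from `utils.quantize_kv(k, v, dtype, k_scale, v_scale)` to the new dtype-first `quantize_kv(dtype, k, v, k_scale, v_scale)` in `tpu_inference.layers.common.quantization`. The helper's body is not part of this diff; the sketch below is only a guess at what a KV-cache quantization routine with that calling convention could look like (per-tensor scale, then cast), with hypothetical names, not the package's implementation.

```python
import jax.numpy as jnp

def quantize_kv_sketch(quant_dtype, k, v, k_scale=None, v_scale=None):
    # Hypothetical stand-in for quantize_kv: scale K/V into the target dtype's
    # range and cast. The real helper may use per-channel scales or other rules.
    def _max_representable(dtype):
        return (jnp.finfo(dtype).max if jnp.issubdtype(dtype, jnp.floating)
                else jnp.iinfo(dtype).max)

    def _quantize(x, scale):
        x = x.astype(jnp.float32)
        if scale is None:
            # Derive a per-tensor scale from the max magnitude (assumption).
            scale = jnp.max(jnp.abs(x)) / _max_representable(quant_dtype)
        y = x / jnp.maximum(scale, 1e-8)
        if not jnp.issubdtype(quant_dtype, jnp.floating):
            y = jnp.clip(jnp.round(y), jnp.iinfo(quant_dtype).min,
                         jnp.iinfo(quant_dtype).max)
        return y.astype(quant_dtype)

    return _quantize(k, k_scale), _quantize(v, v_scale)

k = jnp.linspace(-1.0, 1.0, 16 * 128).reshape(16, 1, 128).astype(jnp.bfloat16)
v = jnp.linspace(-2.0, 2.0, 16 * 128).reshape(16, 1, 128).astype(jnp.bfloat16)
k_q, v_q = quantize_kv_sketch(jnp.int8, k, v)
print(k_q.dtype, v_q.dtype)  # int8 int8
```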
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import dataclasses
  from dataclasses import dataclass, fields
  from typing import Any, Callable, Mapping
@@ -1,3 +1,16 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
  """
  Current Used Abbreviation for Tensor Dimensions:
  B: Batch size
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from dataclasses import InitVar, dataclass
  from typing import Any

@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import math
  from typing import Tuple

@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.