tpu_inference-0.12.0.dev20251222-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
+++ tpu_inference/layers/jax/attention/gpt_oss_attention.py
@@ -0,0 +1,275 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from dataclasses import InitVar, dataclass
+ from typing import Tuple
+
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from flax.typing import Sharding
+ from jax.sharding import Mesh
+ from jax.sharding import PartitionSpec as P
+
+ from tpu_inference import utils
+ from tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 import \
+     ragged_paged_attention_hd64
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.jax.base import create_param
+ from tpu_inference.layers.jax.rope import GptOssRotaryEmbedding
+
+ KVCache = Tuple[jax.Array, jax.Array]
+
+
+ @dataclass(kw_only=True)
+ class GptOssAttention(nnx.Module):
+     """
+     JAX implementation of the GPT-OSS Attention block
+     """
+     hidden_size: int
+     num_attention_heads: int
+     num_key_value_heads: int
+     head_dim: int
+     dtype: jnp.dtype
+     rngs: InitVar[nnx.Rngs]
+
+     rope_theta: float
+     initial_context_length: int = 4096
+     rope_scaling_factor: float = 32.0
+     rope_ntk_alpha: float = 1.0
+     rope_ntk_beta: float = 32.0
+     kv_cache_dtype: str
+
+     query_tnh: P = P()
+     keyvalue_skh: P = P()
+     attn_o_tnh: P = P()
+     dnh_sharding: Sharding = ()
+     dkh_sharding: Sharding = ()
+     nhd_sharding: Sharding = ()
+     n_sharding: Sharding = ()
+     nh_sharding: Sharding = ()
+     kh_sharding: Sharding = ()
+     d_sharding: Sharding = ()
+
+     random_init: bool = False
+     mesh: Mesh
+
+     _q_scale: float = 1.0
+     _k_scale: float = 1.0
+     _v_scale: float = 1.0
+     kv_cache_quantized_dtype = None
+
+     def __post_init__(self, rngs: nnx.Rngs):
+         """Initializes weights, biases, and RoPE module."""
+
+         self.sm_scale = 1.0 / (self.head_dim**0.5)
+
+         self.sinks_N = create_param(
+             rngs,
+             shape=(self.num_attention_heads, ),
+             dtype=jnp.float32,
+             sharding=self.n_sharding,
+             random_init=self.random_init,
+         )
+
+         # Q, K, V projection kernels
+         self.kernel_q_DNH = create_param(
+             rngs,
+             shape=(self.hidden_size, self.num_attention_heads, self.head_dim),
+             dtype=self.dtype,
+             sharding=self.dnh_sharding,
+             random_init=self.random_init,
+         )
+         self.bias_q_NH = create_param(
+             rngs,
+             shape=(self.num_attention_heads, self.head_dim),
+             dtype=self.dtype,
+             sharding=self.nh_sharding,
+             random_init=self.random_init,
+         )
+         self.kernel_k_DKH = create_param(
+             rngs,
+             shape=(self.hidden_size, self.num_key_value_heads, self.head_dim),
+             dtype=self.dtype,
+             sharding=self.dkh_sharding,
+             random_init=self.random_init,
+         )
+         self.bias_k_KH = create_param(
+             rngs,
+             shape=(self.num_key_value_heads, self.head_dim),
+             dtype=self.dtype,
+             sharding=self.kh_sharding,
+             random_init=self.random_init,
+         )
+         self.kernel_v_DKH = create_param(
+             rngs,
+             shape=(self.hidden_size, self.num_key_value_heads, self.head_dim),
+             dtype=self.dtype,
+             sharding=self.dkh_sharding,
+             random_init=self.random_init,
+         )
+         self.bias_v_KH = create_param(
+             rngs,
+             shape=(self.num_key_value_heads, self.head_dim),
+             dtype=self.dtype,
+             sharding=self.kh_sharding,
+             random_init=self.random_init,
+         )
+         # Output projection kernel
+         self.kernel_o_proj_NHD = create_param(
+             rngs,
+             shape=(self.num_attention_heads, self.head_dim, self.hidden_size),
+             dtype=self.dtype,
+             sharding=self.nhd_sharding,
+             random_init=self.random_init,
+         )
+         self.bias_o_D = create_param(
+             rngs,
+             shape=(self.hidden_size, ),
+             dtype=self.dtype,
+             sharding=self.d_sharding,
+             random_init=self.random_init,
+         )
+
+         # RoPE Module
+         self.rope = GptOssRotaryEmbedding(
+             head_dim=self.head_dim,
+             rope_theta=self.rope_theta,
+             dtype=self.dtype,
+             initial_context_length=self.initial_context_length,
+             rope_scaling_factor=self.rope_scaling_factor,
+             rope_ntk_alpha=self.rope_ntk_alpha,
+             rope_ntk_beta=self.rope_ntk_beta)
+
+         if self.kv_cache_dtype != "auto":
+             self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
+                 self.kv_cache_dtype)
+
+     def attention(
+         self,
+         kv_cache: KVCache,
+         q_TNH: jax.Array,
+         k_SKH: jax.Array,
+         v_SKH: jax.Array,
+         sinks: jax.Array,
+         attention_metadata: AttentionMetadata,
+         mesh: Mesh,
+         q_scale: float | None = None,
+         k_scale: float | None = None,
+         v_scale: float | None = None,
+     ) -> Tuple[KVCache, jax.Array]:
+         """Performs scaled dot-product attention by calling the ragged_paged_attention kernel."""
+         md = attention_metadata
+         kv_cache_spec = P("data", None, "model")
+
+         in_specs = (
+             self.query_tnh,  # q
+             self.keyvalue_skh,  # k
+             self.keyvalue_skh,  # v
+             kv_cache_spec,  # kv_cache
+             P("data"),  # md.seq_lens
+             P("data"),  # page_indices_flat
+             P("data"),  # query_start_loc
+             P("data"),  # distribution
+             P(('model')),  # sinks
+         )
+         out_specs = (self.attn_o_tnh, kv_cache_spec)
+
+         def _ragged_paged_attention_wrapper(*args):
+             # Pass the GPT-OSS specific parameters to the kernel
+             return ragged_paged_attention_hd64(
+                 *args,
+                 sm_scale=self.sm_scale,
+                 sliding_window=md.sliding_window,
+                 q_scale=q_scale,
+                 k_scale=k_scale,
+                 v_scale=v_scale,
+             )
+
+         output_TNH, kv_cache = jax.jit(
+             jax.shard_map(
+                 _ragged_paged_attention_wrapper,
+                 mesh=mesh,
+                 in_specs=in_specs,
+                 out_specs=out_specs,
+                 check_vma=False,
+             ))(
+                 q_TNH,
+                 k_SKH,
+                 v_SKH,
+                 kv_cache,
+                 md.seq_lens,
+                 md.block_tables,
+                 md.query_start_loc,
+                 md.request_distribution,
+                 sinks,
+             )
+         return kv_cache, output_TNH
+
+     def __call__(self,
+                  x_TD,
+                  is_prefill,
+                  kv_cache: KVCache,
+                  attention_metadata: AttentionMetadata,
+                  use_attention_rope: bool = True):
+         """Forward pass for the Attention module using 3D kernels."""
+         md = attention_metadata
+         x_TD = jnp.asarray(x_TD, self.dtype)
+
+         with jax.named_scope("q_proj"):
+             q_TNH = jnp.einsum("TD,DNH->TNH", x_TD, self.kernel_q_DNH.value)
+             q_TNH += self.bias_q_NH.value
+
+         with jax.named_scope("k_proj"):
+             k_TKH = jnp.einsum("TD,DKH->TKH", x_TD, self.kernel_k_DKH.value)
+             k_TKH += self.bias_k_KH.value
+
+         with jax.named_scope("v_proj"):
+             v_TKH = jnp.einsum("TD,DKH->TKH", x_TD, self.kernel_v_DKH.value)
+             v_TKH += self.bias_v_KH.value
+
+         if use_attention_rope:
+             q_TNH, k_TKH = self.rope(q_TNH, k_TKH, md.input_positions)
+
+         q_scale = k_scale = v_scale = None
+         if self.kv_cache_quantized_dtype:
+             # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
+             # q_scale = self._q_scale
+             k_scale = self._k_scale
+             v_scale = self._v_scale
+             k_TKH, v_TKH = utils.quantize_kv(k_TKH, v_TKH,
+                                              self.kv_cache_quantized_dtype,
+                                              k_scale, v_scale)
+
+         with jax.named_scope("attn_op"):
+             new_kv_cache, attn_out_TNH = self.attention(
+                 kv_cache=kv_cache,
+                 q_TNH=q_TNH,
+                 k_SKH=k_TKH,
+                 v_SKH=v_TKH,
+                 sinks=self.sinks_N.value,
+                 attention_metadata=md,
+                 mesh=self.mesh,
+                 q_scale=q_scale,
+                 k_scale=k_scale,
+                 v_scale=v_scale,
+             )
+         attn_out_TNH = attn_out_TNH[..., :self.head_dim]
+
+         with jax.named_scope("o_proj"):
+             output_TD = jnp.einsum("TNH,NHD->TD", attn_out_TNH,
+                                    self.kernel_o_proj_NHD.value)
+             output_TD += self.bias_o_D.value
+
+         return new_kv_cache, output_TD
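Note: the einsum subscripts above follow one naming convention throughout this file: T = tokens, D = hidden size, N = query heads, K = key/value heads, H = head dim (S is the key/value token axis). The following is a minimal, self-contained sketch of the q/k projections with made-up sizes (not taken from any GPT-OSS config), showing the shapes the "q_proj" and "k_proj" scopes produce:

import jax
import jax.numpy as jnp

# Hypothetical sizes, for illustration only; real values come from the model config.
T, D, N, K, H = 8, 64, 4, 2, 16

key = jax.random.PRNGKey(0)
x_TD = jax.random.normal(key, (T, D))
kernel_q_DNH = jax.random.normal(key, (D, N, H))
kernel_k_DKH = jax.random.normal(key, (D, K, H))

# Same contractions as in GptOssAttention.__call__ ("q_proj" / "k_proj" scopes).
q_TNH = jnp.einsum("TD,DNH->TNH", x_TD, kernel_q_DNH)
k_TKH = jnp.einsum("TD,DKH->TKH", x_TD, kernel_k_DKH)
print(q_TNH.shape, k_TKH.shape)  # (8, 4, 16) (8, 2, 16)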
+++ tpu_inference/layers/jax/attention/llama4_attention.py
@@ -0,0 +1,167 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from dataclasses import dataclass
+
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from jax.sharding import Sharding
+
+ from tpu_inference import utils
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.jax.attention.attention import Attention, KVCache
+ from tpu_inference.layers.jax.rope_interface import apply_rope
+ from tpu_inference.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class L2Norm(nnx.Module):
+     """
+     Implementation of L2 Norm in JAX (taken from MaxText repo - maxtext/MaxText/layers/attentions.py).
+
+     Attributes:
+       eps: float, epsilon used for numerical stability (default value should be ok for most cases).
+     """
+
+     def __init__(self, eps: float = 1e-6):
+         self.eps = eps
+
+     def __call__(self, x):
+         return x * jax.lax.rsqrt(
+             jnp.mean(x**2, axis=-1, keepdims=True) + self.eps)
+
+
+ @dataclass(kw_only=True)
+ class Llama4Attention(Attention):
+     use_qk_norm: bool
+     temperature_tuning: bool
+     temperature_tuning_floor_scale: float
+     temperature_tuning_scale: float
+     activation_attention_td: Sharding
+     activation_attention_out_td: Sharding
+
+     def __call__(self,
+                  x,
+                  is_prefill,
+                  kv_cache: KVCache,
+                  attention_metadata: AttentionMetadata,
+                  use_attention_rope: bool = True):
+         """Performs the forward pass of the attention module.
+
+         This method computes the attention output by projecting the input `x`
+         to queries, keys, and values, applying RoPE and L2Norm if specified,
+         performing scaled dot-product attention, and projecting the results
+         back to the model dimension.
+         If RoPE is not used (NoPE), temperature tuning can also be applied,
+         which is useful to combat dilution of attention scores in long-context attention.
+
+         Args:
+             x: The input tensor of shape `(seq_len, d_model)`.
+             is_prefill: Whether the operation mode is prefill (otherwise it is generate).
+             kv_cache: The key-value cache for storing past attention states.
+             attention_metadata: Metadata for attention, such as input positions.
+             use_attention_rope: Whether to use RoPE.
+
+         Returns:
+             A tuple containing:
+             - The updated KV cache.
+             - The attention output tensor of shape
+               `(seq_len, d_model)`.
+         """
+         md = attention_metadata
+         x = jnp.asarray(x, self.dtype)
+         x_SD = nnx.with_sharding_constraint(x, self.activation_attention_td)
+         x_q_TD = nnx.with_sharding_constraint(x, self.activation_q_td)
+         rope_scaling = self.rope_scaling
+         rope_theta = self.rope_theta
+         H = self.head_dim
+         l2_norm = L2Norm()
+
+         with jax.named_scope("q_proj"):
+             q_TNH = jnp.einsum('TD,DNH -> TNH', x_q_TD,
+                                self.kernel_q_proj_DNH.value)
+             if use_attention_rope:
+                 q_TNH = apply_rope(q_TNH, md.input_positions, H, rope_theta,
+                                    rope_scaling, self.rope_input_ordering)
+
+                 # Apply normalization after RoPE
+                 if self.use_qk_norm:
+                     q_TNH = l2_norm(q_TNH)
+             else:
+                 if self.temperature_tuning:
+                     q_TNH = self.apply_temperature_tuning(md, q_TNH)
+
+             q_TNH = nnx.with_sharding_constraint(q_TNH, self.query_tnh)
+         with jax.named_scope("k_proj"):
+             k_SKH = jnp.einsum('SD,DKH -> SKH', x_SD,
+                                self.kernel_k_proj_DKH.value)
+             if use_attention_rope:
+                 k_SKH = apply_rope(k_SKH, md.input_positions, H, rope_theta,
+                                    rope_scaling, self.rope_input_ordering)
+
+                 # Apply normalization after RoPE
+                 if self.use_qk_norm:
+                     k_SKH = l2_norm(k_SKH)
+             k_SKH = nnx.with_sharding_constraint(k_SKH, self.keyvalue_skh)
+
+         with jax.named_scope("v_proj"):
+             v_SKH = jnp.einsum('SD,DKH -> SKH', x_SD,
+                                self.kernel_v_proj_DKH.value)
+             v_SKH = nnx.with_sharding_constraint(v_SKH, self.keyvalue_skh)
+
+         q_scale = k_scale = v_scale = None
+         if self.kv_cache_quantized_dtype:
+             # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
+             # q_scale = self._q_scale
+             k_scale = self._k_scale
+             v_scale = self._v_scale
+             k_SKH, v_SKH = utils.quantize_kv(k_SKH, v_SKH,
+                                              self.kv_cache_quantized_dtype,
+                                              k_scale, v_scale)
+
+         with jax.named_scope("attn_op"):
+             new_kv_cache, outputs_TNH = self.attention(
+                 is_prefill,
+                 kv_cache,
+                 q_TNH,
+                 k_SKH,
+                 v_SKH,
+                 attention_metadata,
+                 self.mesh,
+                 q_scale=q_scale,
+                 k_scale=k_scale,
+                 v_scale=v_scale,
+             )
+
+         with jax.named_scope("o_proj"):
+             o_TD = jnp.einsum('TNH,NHD -> TD', outputs_TNH,
+                               self.kernel_o_proj_NHD.value)
+             o_TD = nnx.with_sharding_constraint(
+                 o_TD, self.activation_attention_out_td)
+         return new_kv_cache, o_TD
+
+     def apply_temperature_tuning(self, md: AttentionMetadata,
+                                  input_arr_TNH: jax.Array) -> jax.Array:
+         """Applies temperature tuning to the input array of shape (T, N, H).
+         Args:
+             md: AttentionMetadata object containing the input positions.
+             input_arr_TNH: Input array of shape (T, N, H) which will have scaled temperatures applied.
+         """
+         attn_scales = (jnp.log(
+             jnp.floor((md.input_positions.astype(self.dtype) + 1.0) /
+                       self.temperature_tuning_floor_scale) + 1.0) *
+                        self.temperature_tuning_scale + 1.0)
+         return input_arr_TNH * attn_scales[:, None, None]
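Note: for NoPE layers, apply_temperature_tuning scales each query by log(floor((p + 1) / floor_scale) + 1) * scale + 1, where p is the token position, so the scale stays at 1.0 for short contexts and grows logarithmically with position. A standalone sketch with hypothetical parameter values (the real ones come from the Llama 4 model config) illustrates the effect:

import jax.numpy as jnp

# Hypothetical tuning parameters, for illustration only.
floor_scale, scale = 8192.0, 0.1

positions = jnp.array([0.0, 4095.0, 65535.0, 262143.0])
# Same formula as apply_temperature_tuning above.
attn_scales = jnp.log(jnp.floor((positions + 1.0) / floor_scale) + 1.0) * scale + 1.0
print(attn_scales)  # approx. [1.0, 1.0, 1.22, 1.35]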
+++ tpu_inference/layers/jax/base.py
@@ -0,0 +1,165 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import dataclasses
+ from dataclasses import dataclass, fields
+ from typing import Any, Callable, Mapping
+
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from flax.typing import Sharding
+ from jax.sharding import PartitionSpec as P
+
+ from tpu_inference.logger import init_logger
+
+ # Type alias for Initializer for cleaner type hints
+ Initializer = Callable[..., jax.Array]
+ logger = init_logger(__name__)
+
+ # Define singleton initializers to avoid re-compilation.
+ _scale_initializer = nnx.initializers.ones
+ _sharded_initializer = nnx.initializers.xavier_normal()
+ _init_fn = nnx.initializers.uniform()
+
+
+ @dataclass
+ class Config:
+     """Base configuration class with a robust factory method.
+
+     This class provides a `from_cfg` classmethod that allows creating a config
+     instance from a dictionary, ensuring that all required fields are present
+     and ignoring any extraneous keys.
+     """
+
+     @classmethod
+     def from_cfg(cls, cfg: dict[str, Any] | None = None, **kwargs):
+         """Creates a config instance from a dictionary and/or keyword arguments.
+
+         This factory method validates that all fields without default values
+         are provided in the input dictionary or keyword arguments.
+
+         Args:
+             cfg: A dictionary of configuration parameters.
+             **kwargs: Additional configuration parameters passed as keyword arguments.
+
+         Returns:
+             An instance of the configuration class.
+
+         Raises:
+             ValueError: If any required parameters are missing.
+         """
+         if cfg is None:
+             cfg = {}
+         cfg.update(kwargs)
+
+         required_params = {
+             f.name
+             for f in fields(cls) if f.default is dataclasses.MISSING
+             and f.default_factory is dataclasses.MISSING
+         }
+
+         # Check if any of the truly required parameters are missing from the provided config.
+         missing_params = required_params - set(cfg.keys())
+         if missing_params:
+             raise ValueError(
+                 f"Missing required parameters for {cls.__name__}: {', '.join(sorted(list(missing_params)))}"
+             )
+
+         known_params = {f.name for f in fields(cls)}
+         filtered_cfg = {k: v for k, v in cfg.items() if k in known_params}
+
+         return cls(**filtered_cfg)
+
+     # TODO: check logic with some unit tests.
+     def maybe_apply_overrides(self):
+         """Update the args with additional_config, hf_overrides, and override_generation_config settings.
+         If there is overlap in overrides between the configs, then print a warning declaring which
+         overrides will take precedence."""
+
+         if not getattr(self, "vllm_config"):
+             return
+
+         def _overrides_str(original: str, original_val: Any,
+                            new_val: Any) -> str:
+             return f"{original}: {original_val} ---> {new_val}"
+
+         def _get_overrides_dict(self) -> Mapping[str, Any]:
+             """Return the overrides from all of the possible vllm sections."""
+             overrides_dict = {}
+             vllm_model_config = self.vllm_config.model_config
+
+             for override_type in ordered_override_types:
+                 if override_type == "additional_config":
+                     overrides_dict[
+                         override_type] = self.vllm_config.additional_config
+                 else:
+                     overrides_dict[override_type] = getattr(
+                         vllm_model_config, override_type)
+             return overrides_dict
+
+         ordered_override_types = [
+             "additional_config", "hf_overrides", "override_generation_config"
+         ]
+
+         overrides_dict = _get_overrides_dict(self)
+
+         # Override the config values using the vLLM sections with highest
+         # precedence first.
+         for field in fields(self):
+             selected_type = None
+             for override_type in reversed(ordered_override_types):
+                 if field.name in overrides_dict[override_type]:
+                     setattr(self, field.name,
+                             overrides_dict[override_type][field.name])
+                     selected_type = override_type
+                     break
+             if selected_type is None:
+                 continue
+
+             # If multiple vLLM sections contain overrides, print a warning.
+             for override_type in ordered_override_types:
+                 if override_type == selected_type:
+                     break
+                 else:
+                     if field.name in overrides_dict[override_type]:
+                         overriden_keys_str = _overrides_str(
+                             field.name,
+                             overrides_dict[override_type][field.name],
+                             overrides_dict[selected_type][field.name])
+                         logger.warning(
+                             f"Overriding {override_type} arguments with the following {selected_type} args: {overriden_keys_str}"
+                         )
+
+     def __post_init__(self):
+         self.maybe_apply_overrides()
+
+
+ def create_param(rngs: nnx.Rngs,
+                  shape: tuple[int, ...],
+                  sharding: Sharding = (),
+                  dtype: Any = jnp.float32,
+                  random_init=False) -> nnx.Param:
+     key = rngs.params()
+     if random_init:
+         initializer = _scale_initializer if len(
+             shape) == 1 else _sharded_initializer
+
+         jitted_initializer = jax.jit(initializer,
+                                      static_argnames=('shape', 'dtype'),
+                                      out_shardings=P(*sharding))
+         param_data = jitted_initializer(key, shape, dtype)
+         return nnx.Param(param_data, sharding=sharding)
+     else:
+         return nnx.Param(_init_fn(key, shape, dtype), sharding=sharding)
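Note: a minimal usage sketch of the Config.from_cfg factory follows, assuming tpu_inference.layers.jax.base is importable. The subclass is hypothetical and not part of the package; it gives vllm_config a None default so that the inherited __post_init__ override hook is a no-op in this sketch.

from dataclasses import dataclass
from typing import Any

from tpu_inference.layers.jax.base import Config

# Hypothetical config subclass, for illustration only.
@dataclass
class ToyConfig(Config):
    hidden_size: int
    num_layers: int
    dropout: float = 0.0
    vllm_config: Any = None  # keeps maybe_apply_overrides() a no-op in this sketch

# Extra keys are filtered out; missing required fields raise ValueError.
cfg = ToyConfig.from_cfg({"hidden_size": 128, "num_layers": 2, "unused_key": 1})
print(cfg)  # ToyConfig(hidden_size=128, num_layers=2, dropout=0.0, vllm_config=None)
# ToyConfig.from_cfg({"hidden_size": 128})  # would raise ValueError: missing num_layers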