tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/layers/common/attention_interface.py
@@ -0,0 +1,403 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import math
+from typing import Any, Callable, Optional, Tuple
+
+import jax
+import jax.numpy as jnp
+from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention
+from jax.experimental.pallas.ops.tpu.splash_attention import \
+    splash_attention_kernel as splash
+from jax.experimental.pallas.ops.tpu.splash_attention import \
+    splash_attention_mask as mask_lib
+from jax.sharding import Mesh
+from jax.sharding import PartitionSpec as P
+
+import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
+import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
+from tpu_inference.kernels.flash_attention.kernel import flash_attention
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.utils import get_megacore
+
+MAX_ALLOWED_PAGE_INDICES_N = (
+    128 * 1024
+)  # Based on experiments on v5e: 256 * 1024 causes SMEM OOM, but 128 * 1024 does not. TODO: Adjust this based on TPU version.
+
+ragged_paged_attention = rpa.ragged_paged_attention
+get_kv_cache_shape = rpa.get_kv_cache_shape
+
+ragged_paged_attention_hd64 = rpa_hd64.ragged_paged_attention_hd64
+get_kv_cache_shape_hd64 = rpa_hd64.get_kv_cache_shape
+
+
+def sharded_flash_attention(
+    mesh: Mesh,
+    causal: bool = True,
+    sm_scale: Optional[float] = None,
+    vmem_limit_bytes: int | None = None,
+) -> Callable[..., Any]:
+    in_specs = (
+        P("data", "model", None, None),  # q
+        P("data", "model", None, None),  # k
+        P("data", "model", None, None),  # v
+        P(),  # segment_ids
+    )
+    out_specs = P("data", "model", None, None)
+
+    def _flash_attention(q, k, v, segment_ids):
+        return flash_attention(q,
+                               k,
+                               v,
+                               segment_ids=segment_ids,
+                               sm_scale=sm_scale,
+                               causal=causal,
+                               vmem_limit_bytes=vmem_limit_bytes)
+
+    return jax.jit(
+        jax.shard_map(_flash_attention,
+                      mesh=mesh,
+                      in_specs=in_specs,
+                      out_specs=out_specs,
+                      check_vma=False))
+
+
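A minimal usage sketch for the wrapper above (illustrative only, not part of the packaged file); the 1 x N device mesh layout and the scale value are assumptions:

    # Build a ("data", "model") mesh over the local TPU devices and get the jitted wrapper.
    mesh = jax.make_mesh((1, jax.device_count()), ("data", "model"))
    flash = sharded_flash_attention(mesh, causal=True, sm_scale=128 ** -0.5)
    # flash(q, k, v, segment_ids) then expects (batch, heads, seq, head_dim) operands,
    # sharded along the "data" and "model" axes exactly as declared in in_specs above.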
+def sharded_paged_attention(
+    mesh: Mesh,
+    attn_logits_soft_cap: Optional[float] = None,
+) -> Callable[..., Any]:
+    """Shards GQA PagedAttention along KV heads."""
+    in_specs = (
+        P(None, "model", None),  # q
+        P("model", None, None, None),  # k
+        P("model", None, None, None),  # v
+        P(),  # lengths
+        P(),  # page_indices
+    )
+    out_specs = P(None, "model", None)
+
+    def _paged_attention_fn(q, k, v, lengths, page_indices):
+        if page_indices.size > MAX_ALLOWED_PAGE_INDICES_N:
+            raise ValueError(
+                "This will result in smem OOM. Use `paged_attention_with_guarded_smem` to run with minibatches."
+            )
+        return paged_attention(
+            q,
+            k,
+            v,
+            lengths,
+            page_indices,
+            attn_logits_soft_cap=attn_logits_soft_cap,
+            pages_per_compute_block=min(
+                16, page_indices.shape[1]),  # 512 / page_size:32,
+            megacore_mode="kv_head" if get_megacore() else None,
+        )
+
+    return jax.jit(
+        jax.shard_map(
+            _paged_attention_fn,
+            mesh=mesh,
+            in_specs=in_specs,
+            out_specs=out_specs,
+            check_vma=False,
+        ))
+
+
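An illustrative reading of the block-size choice above (not part of the packaged file); the page size of 32 tokens is an assumption taken from the inline comment:

    pages_per_compute_block = min(16, blocks_per_seq)  # at most 16 KV pages per compute block
    tokens_per_compute_block = 16 * 32                 # = 512 tokens, i.e. the "512 / page_size:32" note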
+# TODO(xiangxu): merge this with sharded_paged_attention
+@functools.partial(jax.jit, static_argnums=[0])
+def paged_attention_with_guarded_smem(
+    paged_attention_kernel: Callable,
+    q: jax.Array,
+    k_pages: jax.Array,
+    v_pages: jax.Array,
+    lengths: jax.Array,
+    page_indices: jax.Array,
+):
+    # Addresses b/336316706. Summary:
+    # The paged attention kernel stores `lengths` (batch_size * 4 bytes) and `page_indices` (batch_size * num_blocks_per_seq * 4 bytes) in SMEM.
+    # SMEM capacity is quite limited and also TPU-version dependent; models with a longer context length or a larger batch size can cause OOM in SMEM.
+    # There are two solutions:
+    # 1. Reduce blocks per seq by increasing the page size.
+    # 2. Split the batch into several minibatches (higher perf based on my benchmark).
+
+    batch_size, blocks_per_seq = page_indices.shape
+
+    if page_indices.size <= MAX_ALLOWED_PAGE_INDICES_N:
+        return paged_attention_kernel(q, k_pages, v_pages, lengths,
+                                      page_indices)
+
+    mini_batch_size = MAX_ALLOWED_PAGE_INDICES_N // blocks_per_seq
+
+    # If batch_size is not divisible by mini_batch_size,
+    # we set mini_batch_size to a smaller value, i.e. the GCD,
+    # which will trigger more kernel launches, but that is fine.
+    # TODO: Fix --decode_seqs_padding with this limitation.
+    mini_batch_size = math.gcd(batch_size, mini_batch_size)
+
+    num_kernel_launches = batch_size // mini_batch_size
+
+    outputs = jnp.zeros_like(q).reshape(
+        (num_kernel_launches, mini_batch_size, *q.shape[1:]))
+    q = q.reshape((num_kernel_launches, mini_batch_size, *q.shape[1:]))
+    seq_lens = lengths.reshape((num_kernel_launches, mini_batch_size))
+    block_indices = page_indices.reshape(
+        (num_kernel_launches, mini_batch_size, page_indices.shape[1]))
+
+    for i in range(num_kernel_launches):
+        outputs = outputs.at[i].set(
+            paged_attention_kernel(q[i], k_pages, v_pages, seq_lens[i],
+                                   block_indices[i]))
+
+    outputs = outputs.reshape((batch_size, *outputs.shape[2:]))
+
+    return outputs
+
+
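To make the SMEM guard above concrete, here is an illustrative calculation (not part of the packaged file); the batch size and blocks-per-sequence values are made up:

    batch_size, blocks_per_seq = 512, 512               # page_indices.shape
    page_indices_size = batch_size * blocks_per_seq     # 262144 > 131072 == MAX_ALLOWED_PAGE_INDICES_N
    mini_batch_size = (128 * 1024) // blocks_per_seq    # 256
    mini_batch_size = math.gcd(batch_size, mini_batch_size)  # gcd(512, 256) = 256
    num_kernel_launches = batch_size // mini_batch_size      # 512 // 256 = 2 kernel launches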
+# ruff: noqa: E741
+def update_cache(
+    is_prefill,
+    cache,
+    indices,
+    operand,
+    prefill_seq_len=None,
+    sliding_window=None,
+) -> jax.Array:
+
+    # (8, 55640, 32, 128) (1, 8, 256, 128) -> K (8, 8, 32, 128)
+    # I = B * T // S
+    # k cache, operand
+
+    B, K, T, H = operand.shape
+    K_c, L, S, H = cache.shape
+    assert K == K_c
+    # NOTE: Updating the cache is pretty tricky:
+    # 1. Random-access updates to the cache are not as performant as slice updates.
+    #    If random access is necessary, make sure the indexing count is as small as possible.
+    # 2. Random-access updates may trigger an extra transpose (memory copy) of the cache,
+    #    which is a disaster because the cache is huge. This is a data formatting op inserted by
+    #    the XLA compiler and not well documented.
+    # To mitigate the issues above:
+    # For prefill:
+    #   We reshape the operand so that we can update the cache block-wise, which only requires the block indices.
+    # For decode:
+    #   We reshape the cache so that we can update the cache token-wise, which only requires the token indices (block_id + offset).
+    if is_prefill:
+        # In the case of a sliding window, we should select sliding_window tokens from the actual prompt, not from the padded tokens.
+        if sliding_window and T > sliding_window:
+            assert B == 1
+            start_index = jax.lax.max(0, prefill_seq_len - sliding_window)
+            operand = jax.lax.dynamic_slice_in_dim(
+                operand, start_index, sliding_window,
+                axis=2)  # TODO: @pooyam Perf check this.
+            T = sliding_window
+
+        I = B * T // S
+        # cache: (K, L, S, H)
+        # operand: (B, K, T, H) -> (K, I, S, H)
+        # indices: (B, T // S) -> (I,)
+        operand = jnp.swapaxes(operand, 0, 1).reshape(K, I, S, H)
+        indices = indices.reshape(I)
+        cache = cache.at[:, indices, :, :].set(operand)
+    else:
+        # cache: (K, L, S, H) -> (K, L * S, H)
+        # operand: (B, K, 1, H) -> (K, B, H)
+        # indices: (B,)
+        cache = cache.reshape(K, L * S, H)
+        operand = jnp.swapaxes(operand, 0, 1).reshape(K, B, H)
+        # NOTE: `cache.at[:, indices, :].set()` would trigger the extra transpose of the cache;
+        # the `jnp.arange(K)[..., None]` indexing trick avoids it.
+        cache = cache.at[jnp.arange(K)[..., None], indices, :].set(operand)
+        cache = cache.reshape(K, L, S, H)
+    return cache
+
+
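A shape walkthrough for the prefill branch above (illustrative only, not part of the packaged file); the concrete sizes are taken from the example shapes in the function's own comment:

    # cache:   (K=8, L=55640, S=32, H=128)   KV cache of 55640 blocks of 32 tokens each
    # operand: (B=1, K=8, T=256, H=128)      256 newly prefilled tokens
    # I = B * T // S = 1 * 256 // 32 = 8     the operand fills exactly 8 blocks
    # operand -> swapaxes/reshape -> (K=8, I=8, S=32, H=128)
    # indices -> reshape          -> (8,)    one block index per filled block
    # cache.at[:, indices, :, :].set(operand) writes 8 whole blocks per KV head,
    # instead of issuing 256 per-token random-access updates.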
+@functools.partial(
+    jax.jit, static_argnames=["window_size", "attn_logits_soft_cap", "is_mqa"])
+def apply_splash(q, k, v, window_size, attn_logits_soft_cap,
+                 is_mqa) -> jax.Array:
+    # q: (batch_size, num_heads, seq_len, head_dim)
+    num_heads = q.shape[1]
+    q_seq_len = q.shape[2]
+    kv_seq_len = k.shape[2]
+    assert kv_seq_len >= q_seq_len
+
+    masks = [
+        mask_lib.LocalMask((q_seq_len, kv_seq_len), (window_size, 0),
+                           kv_seq_len - q_seq_len) for _ in range(num_heads)
+    ]
+    mask = mask_lib.MultiHeadMask(tuple(m for m in masks))
+    block_sizes = splash.BlockSizes.get_default()
+
+    if is_mqa:
+        attn = splash.make_splash_mqa_single_device(
+            mask,
+            block_sizes=block_sizes,
+            attn_logits_soft_cap=attn_logits_soft_cap)
+    else:
+        attn = splash.make_splash_mha_single_device(
+            mask,
+            block_sizes=block_sizes,
+            attn_logits_soft_cap=attn_logits_soft_cap)
+    attn = jax.vmap(attn)
+    outputs = attn(q, k, v, None)
+
+    return outputs
+
+
+def sharded_splash_attention(
+    mesh: Mesh,
+    window_size: Optional[int] = None,
+    attn_logits_soft_cap: Optional[float] = None,
+    is_mqa: bool = False,
+) -> Callable[..., Any]:
+    in_specs = (
+        P("data", "model", None, None),  # q
+        P("data", "model", None, None),  # k
+        P("data", "model", None, None),  # v
+    )
+    out_specs = P("data", "model", None, None)
+    return jax.jit(
+        jax.shard_map(
+            functools.partial(
+                apply_splash,
+                window_size=window_size,
+                attn_logits_soft_cap=attn_logits_soft_cap,
+                is_mqa=is_mqa,
+            ),
+            mesh=mesh,
+            in_specs=in_specs,
+            out_specs=out_specs,
+            check_vma=False,
+        ))
+
+
+def sharded_ragged_paged_attention(
+    mesh: Mesh,
+    q: jax.Array,
+    k: jax.Array,
+    v: jax.Array,
+    kv_cache: jax.Array,
+    kv_lens: jax.Array,
+    page_indices: jax.Array,
+    cu_q_lens: jax.Array,
+    distribution: jax.Array,
+    attention_sink: jax.Array | None,
+    sm_scale: float,
+    attention_chunk_size: int | None = None,
+    q_scale: float | None = None,
+    k_scale: float | None = None,
+    v_scale: float | None = None,
+):
+    """Shards along KV heads."""
+
+    qkv_spec = P(ShardingAxisName.ATTN_DATA, ShardingAxisName.ATTN_HEAD, None)
+    kv_cache_spec = P(ShardingAxisName.ATTN_DATA, None,
+                      ShardingAxisName.ATTN_HEAD, None, None)
+    in_specs = (
+        qkv_spec,  # q
+        qkv_spec,  # k
+        qkv_spec,  # v
+        kv_cache_spec,  # kv cache
+        P(ShardingAxisName.ATTN_DATA),  # kv_lens
+        P(ShardingAxisName.ATTN_DATA),  # page_indices
+        P(ShardingAxisName.ATTN_DATA),  # cu_q_lens
+        P(ShardingAxisName.ATTN_DATA),  # distribution
+    )
+    out_specs = (qkv_spec, kv_cache_spec)
+
+    args = (q, k, v, kv_cache, kv_lens, page_indices, cu_q_lens, distribution)
+
+    use_hd64 = q.shape[-1] == 64
+    func = ragged_paged_attention_hd64 if use_hd64 else ragged_paged_attention
+
+    if attention_sink is not None:
+        if not use_hd64:
+            raise NotImplementedError(
+                "Attention sink support is only available when head_dim==64")
+
+        in_specs += (P(ShardingAxisName.ATTN_HEAD), )
+        args += (attention_sink, )
+
+    def _ragged_paged_attention(*args):
+        return func(
+            *args,
+            sm_scale=sm_scale,
+            sliding_window=attention_chunk_size,
+            q_scale=q_scale,
+            k_scale=k_scale,
+            v_scale=v_scale,
+        )
+
+    return jax.shard_map(
+        _ragged_paged_attention,
+        mesh=mesh,
+        in_specs=in_specs,
+        out_specs=out_specs,
+        check_vma=False,
+    )(*args)
+
+
+def attention(
+    kv_cache: jax.Array,
+    q: jax.Array,
+    k: jax.Array,
+    v: jax.Array,
+    attention_metadata: AttentionMetadata,
+    mesh: Mesh,
+    head_dim_original: int | None = None,  # before padding
+    attention_chunk_size: int | None = None,
+    q_scale: float | None = None,
+    k_scale: float | None = None,
+    v_scale: float | None = None,
+    sinks: jax.Array | None = None,
+) -> Tuple[jax.Array, jax.Array]:
+    # T: seq_len
+    # N: num_heads
+    # K: num_kv_heads
+    # D: hidden_size
+    # H: head_dim
+    # L: num_blocks
+    # S: block_size
+
+    # TODO(jevinjiang, cuiq): transpose q weight offline.
+    # q: (T, N, H)
+    # k,v: (T, K, H)
+
+    if head_dim_original is None:
+        head_dim_original = q.shape[-1]
+
+    md = attention_metadata
+
+    # (T, N, H)
+    output, kv_cache = sharded_ragged_paged_attention(
+        mesh,
+        q,
+        k,
+        v,
+        kv_cache,
+        md.seq_lens,
+        md.block_tables,
+        md.query_start_loc,
+        md.request_distribution,
+        sinks,
+        sm_scale=head_dim_original**-0.5,
+        attention_chunk_size=attention_chunk_size,
+        q_scale=q_scale,
+        k_scale=k_scale,
+        v_scale=v_scale,
+    )
+
+    return kv_cache, output
tpu_inference/layers/common/attention_metadata.py
@@ -0,0 +1,48 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+from dataclasses import dataclass, field
+from typing import Any
+
+import jax
+
+
+@functools.partial(
+    jax.tree_util.register_dataclass,
+    data_fields=[
+        "input_positions",
+        "block_tables",
+        "seq_lens",
+        "query_start_loc",
+        "request_distribution",
+    ],
+    meta_fields=[],
+    drop_fields=["query_start_loc_cpu", "seq_lens_cpu"],
+)
+@dataclass
+class AttentionMetadata(object):
+    # (padded_total_num_scheduled_tokens,)
+    input_positions: jax.Array
+    # (max_num_seqs * max_num_blocks_per_req,)
+    block_tables: jax.Array = None
+    # (max_num_seqs,)
+    seq_lens: jax.Array = None
+    # (max_num_seqs + 1,)
+    query_start_loc: jax.Array = None
+    # (3,)
+    request_distribution: jax.Array = None
+
+    query_start_loc_cpu: Any = field(init=False)
+    seq_lens_cpu: Any = field(init=False)
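Illustrative only, not part of the diff: because AttentionMetadata is registered as a JAX dataclass pytree, instances can be passed straight through jit or shard_map; the listed data_fields are the traced leaves, while the *_cpu drop_fields stay host-side. All shapes below are made up.

    import jax
    import jax.numpy as jnp

    md = AttentionMetadata(
        input_positions=jnp.arange(16, dtype=jnp.int32),
        block_tables=jnp.zeros((64,), dtype=jnp.int32),
        seq_lens=jnp.array([16, 0, 0, 0], dtype=jnp.int32),
        query_start_loc=jnp.array([0, 16, 16, 16, 16], dtype=jnp.int32),
        request_distribution=jnp.array([0, 0, 1], dtype=jnp.int32),
    )

    @jax.jit
    def bump_positions(m: AttentionMetadata) -> jax.Array:
        # The metadata arrives as a traced pytree; its arrays behave like any jax.Array.
        return m.input_positions + 1

    bump_positions(md)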