tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260) hide show
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,288 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
+ import functools
5
+
6
+ import jax
7
+ from jax._src import dtypes
8
+ from jax.experimental import pallas as pl
9
+ from jax.experimental.pallas import tpu as pltpu
10
+ from jax.sharding import Mesh
11
+ from jax.sharding import PartitionSpec as P
12
+
13
+ from tpu_inference.utils import TPU_HEAD_SIZE_ALIGNMENT, get_dtype_packing
14
+
15
+
16
+ def _ceil_div(a, b):
17
+ assert b != 0
18
+ return (a + b - 1) // b
19
+
20
+
21
def _kv_cache_update_kernel(
    # Prefetch (scalar-prefetch operands, available before the body runs)
    slices_ref,  # [3, padded_num_slices]; rows are (kv_cache_start,
                 # new_kv_start, slice_len)
    num_slices_ref,  # [1]; count of valid (non-padding) slices
    # Input
    new_kv_hbm_ref,  # [num_tokens, num_combined_kv_heads, head_dim]
    kv_cache_hbm_ref,  # [total_num_pages * page_size, num_combined_kv_heads,
                       # head_dim]
    # Output (aliased with kv_cache_hbm_ref via input_output_aliases in the
    # pallas_call, so the update is effectively in place)
    _,  # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
    # Scratch
    scratch,  # [num_slices_per_block, page_size, num_combined_kv_heads,
              # head_dim]
    sem,  # single DMA semaphore shared by all copies issued in this block
):
    """Copy one block of KV slices from new_kv into the paged KV cache.

    Each grid step handles up to ``num_slices_per_block`` slices in two
    stages: (1) DMA every slice from ``new_kv_hbm_ref`` into the scratch
    buffer, (2) DMA each scratch row out to its destination offset in
    ``kv_cache_hbm_ref``. In each stage all copies are started before any
    is waited on, so the DMAs overlap.
    """
    async_copies = []
    block_idx = pl.program_id(0)
    num_slices_per_block = scratch.shape[0]

    # Stage 1: copy from new_kv_hbm_ref to scratch.
    for i in range(num_slices_per_block):
        offset_i = i + block_idx * num_slices_per_block
        # Padding slices (offset_i >= num_slices) get start=0, length=0 so
        # their copies are no-ops.
        new_kv_start = jax.lax.select(offset_i < num_slices_ref[0],
                                      slices_ref[1, offset_i], 0)
        length = jax.lax.select(offset_i < num_slices_ref[0],
                                slices_ref[2, offset_i], 0)
        async_copy = pltpu.make_async_copy(
            new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
            scratch.at[i, pl.ds(0, length), ...],
            sem,
        )
        async_copy.start()
        async_copies.append(async_copy)

    for async_copy in async_copies:
        async_copy.wait()

    # Stage 2: copy from scratch to kv_cache_hbm_ref.
    async_copies.clear()
    for i in range(num_slices_per_block):
        offset_i = i + block_idx * num_slices_per_block
        kv_cache_start = jax.lax.select(offset_i < num_slices_ref[0],
                                        slices_ref[0, offset_i], 0)
        length = jax.lax.select(offset_i < num_slices_ref[0],
                                slices_ref[2, offset_i], 0)
        async_copy = pltpu.make_async_copy(
            scratch.at[i, pl.ds(0, length), ...],
            kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
            sem,
        )
        async_copy.start()
        async_copies.append(async_copy)
    for async_copy in async_copies:
        async_copy.wait()
76
+
77
+
78
+ def _dynamic_validate_inputs(slices, new_token_num, kv_cache_token_num,
79
+ page_size, num_slices):
80
+ slices = slices.tolist()
81
+ # NOTE: The padding part is unnecessary to check because kv_cache_start, new_kv_start,
82
+ # slice_len will be set to 0 in the kernel implementation.
83
+ for i in range(num_slices[0]):
84
+ kv_cache_start = slices[0][i]
85
+ new_kv_start = slices[1][i]
86
+ slice_len = slices[2][i]
87
+ if new_kv_start < 0:
88
+ raise ValueError(
89
+ f"{new_kv_start=} must be greater than or equal to 0")
90
+ if kv_cache_start < 0:
91
+ raise ValueError(
92
+ f"{kv_cache_start=} must be greater than or equal to 0")
93
+ if not 0 < slice_len <= page_size:
94
+ raise ValueError(
95
+ f"{slice_len=} must be less or equal to {page_size=} and greater than 0"
96
+ )
97
+ if new_kv_start + slice_len > new_token_num:
98
+ raise ValueError(
99
+ f"{new_kv_start=} + {slice_len=} must be less or equal to {new_token_num=}"
100
+ )
101
+ if kv_cache_start + slice_len > kv_cache_token_num:
102
+ raise ValueError(
103
+ f"{kv_cache_start=} + {slice_len=} must be less or equal to {kv_cache_token_num=}"
104
+ )
105
+ if kv_cache_start // page_size != (kv_cache_start + slice_len -
106
+ 1) // page_size:
107
+ raise ValueError(
108
+ f"Each slice must reside in the same page, but got {kv_cache_start=} and {slice_len=}"
109
+ )
110
+
111
+ new_kv_intervals = []
112
+ kv_cache_intervals = []
113
+ for i in range(num_slices[0]):
114
+ new_kv_intervals.append((slices[1][i], slices[1][i] + slices[2][i]))
115
+ kv_cache_intervals.append((slices[0][i], slices[0][i] + slices[2][i]))
116
+
117
+ new_kv_intervals.sort()
118
+ kv_cache_intervals.sort()
119
+
120
+ # The new_kv slices should be continuous
121
+ for i in range(len(new_kv_intervals) - 1):
122
+ if new_kv_intervals[i][1] != new_kv_intervals[i + 1][0]:
123
+ raise ValueError(
124
+ f"{new_kv_intervals[i][1]=} is expeced to equal to {new_kv_intervals[i + 1][0]}"
125
+ )
126
+
127
+ # There should be no overlap among the kv cache slices
128
+ for i in range(len(kv_cache_intervals) - 1):
129
+ if kv_cache_intervals[i][1] > kv_cache_intervals[i + 1][0]:
130
+ raise ValueError(
131
+ f"Overlap detected in kv_cache intervals: {kv_cache_intervals[i]} and {kv_cache_intervals[i+1]}"
132
+ )
133
+
134
+
135
def _kv_cache_update(
    new_kv: jax.Array,  # [total_num_token, num_combined_kv_heads, head_dim]
    slices: jax.Array,  # [3, slices], list of (kv_cache_start, new_kv_start,
                        # slice_len)
    kv_cache: jax.Array,  # [total_num_pages * page_size,
                          # num_combined_kv_heads, head_dim]
    num_slices: jax.Array,  # [1]
    page_size: int,
    num_slices_per_block: int,
    dynamic_validate_inputs: bool,
    vmem_limit_bytes: int = 40 * 1024 * 1024,
):
    """Launch the Pallas KV-cache update kernel on a single device.

    Returns the updated KV cache; the ``kv_cache`` input is aliased with the
    output (see ``input_output_aliases`` below), so the buffer is reused.
    """
    new_token_num, num_combined_kv_heads, head_dim = new_kv.shape
    assert kv_cache.shape[1] == num_combined_kv_heads
    assert kv_cache.shape[2] == head_dim
    # The kernel requires a 128-aligned head_dim.
    assert head_dim % 128 == 0
    if dynamic_validate_inputs is True:
        # Host-side validation; needs concrete slice values (.tolist()).
        _dynamic_validate_inputs(slices, new_token_num, kv_cache.shape[0],
                                 page_size, num_slices)

    # Leave placement of both array inputs up to the compiler (ANY); the
    # kernel moves data through the VMEM scratch buffer itself.
    in_specs = [
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    ]

    out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)]
    out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)]

    # slices/num_slices are scalar-prefetch operands so the kernel can read
    # them to build its DMA descriptors.
    scalar_prefetches = [slices, num_slices]
    scratch = pltpu.VMEM(
        (num_slices_per_block, page_size, num_combined_kv_heads, head_dim),
        new_kv.dtype,
    )

    scratch_shapes = [
        scratch,
        pltpu.SemaphoreType.DMA,
    ]

    kernel = pl.pallas_call(
        _kv_cache_update_kernel,
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=len(scalar_prefetches),
            in_specs=in_specs,
            out_specs=out_specs,
            # One grid step per group of num_slices_per_block slices.
            grid=(_ceil_div(num_slices[0], num_slices_per_block), ),
            scratch_shapes=scratch_shapes,
        ),
        out_shape=out_shape,
        # Alias the kv_cache operand (positioned after the scalar prefetches
        # and new_kv) with output 0 so the cache is updated in place.
        input_output_aliases={len(scalar_prefetches) + 1: 0},
        compiler_params=pltpu.CompilerParams(
            vmem_limit_bytes=vmem_limit_bytes, ),
    )

    return kernel(*scalar_prefetches, new_kv, kv_cache)[0]
191
+
192
+
193
+ def _prev_power_of_2(n: int) -> int:
194
+ """The previous power of 2 (inclusive)"""
195
+ if n <= 0:
196
+ return 0
197
+ return 1 << (n.bit_length() - 1)
198
+
199
+
200
def _get_page_size_bytes(block_size: int, num_combined_kv_heads: int,
                         head_size: int, kv_cache_dtype) -> int:
    """Returns the size in bytes of one page of the KV cache."""
    # Newer jax versions expose ``dtypes.bit_width``; older ones only have
    # ``dtypes.itemsize_bits``.
    if hasattr(dtypes, "bit_width"):
        dtype_bits = dtypes.bit_width(kv_cache_dtype)
    else:
        dtype_bits = dtypes.itemsize_bits(kv_cache_dtype)

    # Pad the head size up to the TPU alignment boundary.
    padded_head_size = _ceil_div(
        head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT

    # NOTE: account for the implicit head padding applied by XLA for packed
    # dtypes.
    packing = get_dtype_packing(kv_cache_dtype)
    padded_heads = _ceil_div(num_combined_kv_heads, packing) * packing

    return block_size * padded_heads * padded_head_size * dtype_bits // 8
213
+
214
+
215
+ def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int,
216
+ vmem_limit_bytes: int) -> int:
217
+ """Find the optimum number of slices to copy per Pallas program instance.
218
+ Increasing the number of slices copied in one instance of the kernel program
219
+ will increase HBM bandwidth utilization via more in-flight DMAs.
220
+ However, it will also use more VMEM, and experimentally, we observed
221
+ performance regression at 128 slices on v6e, likely due to running
222
+ out of scalar registers. Thus this function will limit the number of
223
+ slices to 64.
224
+ """
225
+ # NOTE: We assume 1MB vmem is used for register spill and others
226
+ assert vmem_limit_bytes >= 1024 * 1024, "vmem_limit_bytes must be at least 1MB"
227
+ num_slices_per_block = (vmem_limit_bytes - 1024 * 1024) // page_size_bytes
228
+ assert num_slices_per_block > 0, "Number of slices should be positive"
229
+ num_slices_per_block = _prev_power_of_2(num_slices_per_block)
230
+ return min(num_slices_per_block, 64)
231
+
232
+
233
@functools.partial(
    jax.jit,
    static_argnames=[
        "page_size", "num_slices_per_block", "mesh", "kv_cache_pspec"
    ],
    donate_argnames="kv_cache",  # kv_cache buffer may be reused for the output
)
def kv_cache_update(
    new_kv: jax.Array,  # [total_num_token, num_combined_kv_heads, head_dim]
    slices: jax.Array,  # [3, slices], list of (kv_cache_start, new_kv_start,
                        # slice_len)
    kv_cache: jax.Array,  # [total_num_pages * page_size,
                          # num_combined_kv_heads, head_dim]
    num_slices: jax.Array,  # [1]
    *,
    page_size: int = 32,
    num_slices_per_block: int | None = None,
    mesh: Mesh | None = None,
    kv_cache_pspec: P
    | None = None,  # Only sharding along head_dim is supported
    dynamic_validate_inputs: bool = False,
    vmem_limit_bytes: int = 40 * 1024 * 1024,
):
    """Scatter new KV entries into the paged KV cache via a Pallas kernel.

    When ``mesh`` is given the kernel is wrapped in ``jax.shard_map`` with
    ``kv_cache_pspec`` applied to both ``new_kv`` and ``kv_cache``; otherwise
    it runs on a single device. Returns the updated KV cache array.

    NOTE(review): ``dynamic_validate_inputs`` is not listed in
    ``static_argnames``, so under this jit it is traced and the
    ``is True`` check inside ``_kv_cache_update`` never fires — validation
    appears effectively disabled when called through this wrapper; confirm
    whether that is intentional.
    """
    if num_slices_per_block is None:
        # Derive a slice count per block from the page footprint and the
        # VMEM budget.
        _, num_combined_kv_heads, head_dim = new_kv.shape
        page_size_bytes = _get_page_size_bytes(page_size,
                                               num_combined_kv_heads, head_dim,
                                               kv_cache.dtype)
        num_slices_per_block = _get_num_slices_per_kv_cache_update_block(
            page_size_bytes, vmem_limit_bytes)

    if mesh is None:
        # Single-device path: call the kernel launcher directly.
        return _kv_cache_update(new_kv, slices, kv_cache, num_slices,
                                page_size, num_slices_per_block,
                                dynamic_validate_inputs)

    if kv_cache_pspec is None:
        raise ValueError(
            "kv_cache_pspec must be provided when mesh is specified")

    # new_kv and kv_cache are sharded alike; slices and num_slices are
    # replicated on every device.
    in_specs = (kv_cache_pspec, P(), kv_cache_pspec, P())
    out_specs = kv_cache_pspec
    shard_map_wrapped = jax.shard_map(
        functools.partial(
            _kv_cache_update,
            page_size=page_size,
            num_slices_per_block=num_slices_per_block,
            dynamic_validate_inputs=dynamic_validate_inputs,
            vmem_limit_bytes=vmem_limit_bytes,
        ),
        mesh=mesh,
        in_specs=in_specs,
        out_specs=out_specs,
        check_vma=False,
    )
    return shard_map_wrapped(new_kv, slices, kv_cache, num_slices)