tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/layers/jax/attention/deepseek_v3_attention.py
@@ -0,0 +1,547 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import InitVar, dataclass
+from typing import Any, Tuple
+
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from flax.typing import Sharding
+from jax.sharding import Mesh
+from jax.sharding import PartitionSpec as P
+
+from tpu_inference import utils
+from tpu_inference.kernels.mla.v1.kernel import mla_ragged_paged_attention
+from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
+    ragged_paged_attention
+from tpu_inference.kernels.ragged_paged_attention.v3.tuned_block_sizes import \
+    get_tuned_block_sizes
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.jax.base import create_param
+from tpu_inference.layers.jax.layers import RMSNorm
+from tpu_inference.layers.jax.rope import DeepseekScalingRotaryEmbedding
+
+KVCache = Tuple[jax.Array, jax.Array]
+
+
+# TODO (wenxindongwork): Add MLA KV cache implementation. For now, cache complete KV vectors.
+@dataclass(kw_only=True)
+class MLA(nnx.Module):
+    """An implementation of Multi-Head Latent Attention as
+    described in the DeepSeek V3 paper.
+
+    Attributes:
+        mesh: The JAX device mesh for distributed computation.
+    """
+    hidden_size: int
+    num_attention_heads: int
+    num_key_value_heads: int
+    head_dim: int
+    rope_theta: float
+    rope_scaling: dict[str, Any]
+    dtype: jnp.dtype
+    kv_cache_dtype: str
+    mesh: Mesh
+
+    q_lora_rank: int
+    kv_lora_rank: int
+    qk_nope_head_dim: int
+    qk_rope_head_dim: int
+    v_head_dim: int
+    rms_norm_eps: float
+
+    # Sharding attributes
+    rd_sharding: Sharding = ()
+    q_da_sharding: Sharding = ()
+    ap_sharding: Sharding = ()
+    anh_sharding: Sharding = ()
+    kv_da_sharding: Sharding = ()
+
+    activation_attention_td: Sharding = ()
+    activation_q_td: Sharding = ()
+    query_tnh: P = P()
+    keyvalue_skh: P = P()
+
+    attn_o_tnh: P = P()
+    activation_attention_out_td: Sharding = ()
+
+    random_init: bool = False
+    attention_chunk_size: int | None = None
+    rope_input_ordering: str = "split"
+    quant: Any | None = None
+    rope_mscale_all_dim: float = 1.0
+    use_mla_kernel: bool = False
+
+    rngs: InitVar[nnx.Rngs]
+
+    _q_scale: float = 1
+    _k_scale: float = 1
+    _v_scale: float = 1
+
+    def __post_init__(self, rngs: nnx.Rngs):
+        self.N = self.num_attention_heads
+        self.K = self.num_key_value_heads
+        self.D = self.hidden_size
+        self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
+
+        if not self.use_mla_kernel:
+            assert self.N == self.K, "N and K must be equal for MLA"
+
+        if self.rope_scaling["factor"] <= 1.0:
+            yarn_mscale = 1.0
+        else:
+            yarn_mscale = 0.1 * self.rope_mscale_all_dim * math.log(
+                self.rope_scaling["factor"]) + 1.0
+        self.scale = self.qk_head_dim**-0.5 * yarn_mscale**2
+
+        self.rope = DeepseekScalingRotaryEmbedding(
+            rotary_dim=self.qk_rope_head_dim,
+            rope_theta=self.rope_theta,
+            original_max_position_embeddings=self.
+            rope_scaling["original_max_position_embeddings"],
+            scaling_factor=self.rope_scaling["factor"],
+            dtype=self.dtype,
+            beta_fast=self.rope_scaling["beta_fast"],
+            beta_slow=self.rope_scaling["beta_slow"],
+            mscale_value=self.rope_scaling["mscale"],
+            mscale_all_dim=self.rope_scaling["mscale_all_dim"],
+        )
+
+        # Initializes the weight kernels
+        self.kernel_q_down_proj_DA = create_param(rngs,
+                                                  (self.D, self.q_lora_rank),
+                                                  self.q_da_sharding,
+                                                  self.dtype,
+                                                  random_init=self.random_init)
+        self.kernel_q_up_proj_AP = create_param(
+            rngs,
+            (self.q_lora_rank, self.N * self.qk_head_dim),
+            self.ap_sharding,
+            self.dtype,
+            random_init=self.random_init,
+        )
+        self.kernel_kv_down_proj_DA = create_param(
+            rngs,
+            (self.D, self.kv_lora_rank + self.qk_rope_head_dim),
+            self.kv_da_sharding,
+            self.dtype,
+            random_init=self.random_init,
+        )
+        # NOTE (jacobplatin): we are keeping these variables as 3D because
+        # we would need to reshape them before the below projection,
+        # which caused issues as Qwix wasn't quantizing it correctly
+        # on the abstract pass
+        if self.use_mla_kernel:
+            self.kernel_k_up_proj_ANH = create_param(
+                rngs,
+                (self.kv_lora_rank, self.N, self.qk_nope_head_dim),
+                self.anh_sharding,
+                self.dtype,
+                random_init=self.random_init,
+            )
+            self.kernel_v_up_proj_ANH = create_param(
+                rngs,
+                (self.kv_lora_rank, self.N, self.v_head_dim),
+                self.anh_sharding,
+                self.dtype,
+                random_init=self.random_init,
+            )
+        else:
+            self.kernel_kv_up_proj_AL = create_param(
+                rngs,
+                (self.kv_lora_rank, self.N *
+                 (self.qk_nope_head_dim + self.v_head_dim)),
+                self.
+                ap_sharding,  # NOTE: we use the same sharding for kv_up_proj_AL and kernel_q_up_proj_AP
+                self.dtype,
+                random_init=self.random_init,
+            )
+        self.kernel_o_proj_RD = create_param(
+            rngs, (self.N * self.v_head_dim, self.D),
+            self.rd_sharding,
+            self.dtype,
+            random_init=self.random_init)
+        self.q_rms_norm = RMSNorm(
+            dims=self.q_lora_rank,
+            epsilon=self.rms_norm_eps,
+            with_scale=True,
+            dtype=self.dtype,
+            random_init=self.random_init,
+            rngs=rngs,
+        )
+
+        self.kv_rms_norm = RMSNorm(
+            dims=self.kv_lora_rank,
+            random_init=self.random_init,
+            epsilon=self.rms_norm_eps,
+            with_scale=True,
+            dtype=self.dtype,
+            rngs=rngs,
+        )
+
+        self.kv_cache_quantized_dtype = None
+        if self.kv_cache_dtype != "auto":
+            self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
+                self.kv_cache_dtype)
+
+    def __call__(self,
+                 x,
+                 is_prefill,
+                 kv_cache: KVCache,
+                 attention_metadata: AttentionMetadata,
+                 use_attention_rope: bool = True):
+        """Performs the forward pass of the attention module.
+
+        Args:
+            x: The input tensor of shape `(batch_size, seq_len, d_model)`.
+            is_prefill: Whether the operation mode is prefill (otherwise it is generate).
+            kv_cache: The key-value cache for storing past attention states.
+            attention_metadata: Metadata for attention, such as input positions.
+
+        Returns:
+            A tuple containing:
+            - The updated KV cache.
+            - The attention output tensor of shape
+              `(batch_size, seq_len, d_model)`.
+        """
+        md = attention_metadata
+        x = jnp.asarray(x, self.dtype)
+        x_SD = nnx.with_sharding_constraint(x, self.activation_attention_td)
+        x_q_TD = nnx.with_sharding_constraint(x, self.activation_q_td)
+
+        with jax.named_scope("q_proj"):
+            # Query down projection.
+            q_TA = jnp.einsum("TD,DA -> TA", x_q_TD,
+                              self.kernel_q_down_proj_DA.value)
+            q_TA = self.q_rms_norm(q_TA)
+            # Query up projection, then reshape to TNH.
+            q_TP = jnp.einsum("TA,AP -> TP", q_TA,
+                              self.kernel_q_up_proj_AP.value)
+            q_TNH = q_TP.reshape(q_TA.shape[0], self.N, self.qk_head_dim)
+            # Split the query into nope and rope.
+            q_nope_TNH = q_TNH[..., :self.qk_nope_head_dim]
+            q_rope_TNH = q_TNH[..., self.qk_nope_head_dim:]
+            q_rope_TNH = self.rope.apply_rope(md.input_positions, q_rope_TNH)
+            if self.use_mla_kernel:
+                # Absorb the k up-projection matrix into q
+                q_TNA = jnp.einsum("TNH,ANH -> TNA", q_nope_TNH,
+                                   self.kernel_k_up_proj_ANH.value)
+                q_TNA = nnx.with_sharding_constraint(q_TNA, self.query_tnh)
+            else:
+                # Concatenate the nope and rope queries.
+                q_TNH = jnp.concatenate([q_nope_TNH, q_rope_TNH], axis=-1)
+                # Multiply the query by scaling factor
+                q_TNH = nnx.with_sharding_constraint(q_TNH, self.query_tnh)
+
+        with jax.named_scope("kv_proj"):
+            # KV down projection.
+            kv_SA = jnp.einsum("SD,DA -> SA", x_SD,
+                               self.kernel_kv_down_proj_DA.value)
+            # Split the key and value into latent kv vector and k rope vector.
+            k_rope_SH = kv_SA[..., self.kv_lora_rank:]
+            # Reshape k_rope_SH to include head dimension for RoPE application
+            k_rope_SNH = k_rope_SH[..., None, :]
+            k_rope_SNH = self.rope.apply_rope(md.input_positions, k_rope_SNH)
+            assert k_rope_SNH.shape[1] == 1
+            k_rope_SH = k_rope_SNH[:, 0, :]
+
+            kv_SA = kv_SA[..., :self.kv_lora_rank]
+            kv_SA = self.kv_rms_norm(kv_SA)
+            kv_SA = nnx.with_sharding_constraint(kv_SA, self.keyvalue_skh)
+
+            if not self.use_mla_kernel:
+                k_rope_SNH = jnp.broadcast_to(
+                    k_rope_SNH,
+                    (k_rope_SNH.shape[0], self.N, self.qk_rope_head_dim))
+                # KV up projection, then reshape to SN(Hk+Hv).
+                kv_SL = jnp.einsum("SA,AL -> SL", kv_SA,
+                                   self.kernel_kv_up_proj_AL.value)
+                kv_nope_SNH = kv_SL.reshape(
+                    kv_SA.shape[0], self.N,
+                    self.qk_nope_head_dim + self.v_head_dim)
+                # Split the latent kv vector into k nope vector and v vector.
+                k_nope_SNH = kv_nope_SNH[..., :self.qk_nope_head_dim]
+                v_SNH = kv_nope_SNH[..., self.qk_nope_head_dim:]
+                # Concatenate the key vector.
+                k_SNH = jnp.concatenate([k_nope_SNH, k_rope_SNH], axis=-1)
+                k_SNH = nnx.with_sharding_constraint(k_SNH, self.keyvalue_skh)
+                v_SNH = nnx.with_sharding_constraint(v_SNH, self.keyvalue_skh)
+
+        with jax.named_scope("attn_op"):
+            # TODO(wenxindongwork): K and V have different head dimension,
+            # which is not supported by the current kv cache implementation.
+            # For now we are padding the v dimension to match the k dimension.
+            # Furthermore, deepseekv3 k head dimension is 192, which is
+            # not supported by the current attention kernel, which expects
+            # q, k, v head dimension to be multiple of 128. For now, we will
+            # pad the q, k, v dimension to multiple of 128.
+            # We should update the MLA kv cache implementation in the future.
+            if not self.use_mla_kernel:  # MLA kernel handles padding
+                multiple_of_128 = ((self.qk_head_dim - 1) // 128 + 1) * 128
+                q_TNH = jnp.pad(q_TNH,
+                                ((0, 0), (0, 0),
+                                 (0, multiple_of_128 - self.qk_head_dim)))
+                k_SNH = jnp.pad(k_SNH,
+                                ((0, 0), (0, 0),
+                                 (0, multiple_of_128 - self.qk_head_dim)))
+                v_SNH = jnp.pad(v_SNH,
+                                ((0, 0), (0, 0),
+                                 (0, multiple_of_128 - self.v_head_dim)))
+
+            q_scale = k_scale = v_scale = None
+
+            # TODO(gpolovets): MLA does not currently support quantized KV!
+            if not self.use_mla_kernel:
+                if self.kv_cache_quantized_dtype:
+                    # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
+                    k_scale = self._k_scale
+                    v_scale = self._v_scale
+                    k_SNH, v_SNH = utils.quantize_kv(
+                        k_SNH, v_SNH, self.kv_cache_quantized_dtype, k_scale,
+                        v_scale)
+
+                new_kv_cache, outputs_TNH = self.attention(
+                    is_prefill,
+                    kv_cache,
+                    q_TNH,
+                    k_SNH,
+                    v_SNH,
+                    attention_metadata,
+                    self.mesh,
+                    q_scale,
+                    k_scale,
+                    v_scale,
+                )
+                # TODO(wenxindongwork): For now, unpad the outputs_TNH to match the v_head_dim.
+                # We shall add the MLA kv cache implementation in the future.
+                outputs_TNH = outputs_TNH[..., :self.v_head_dim]
+
+            else:
+                new_kv_cache, outputs_TNA = self.mla_attention(
+                    kv_cache,
+                    q_TNA,
+                    q_rope_TNH,
+                    kv_SA,
+                    k_rope_SH,
+                    attention_metadata,
+                    self.mesh,
+                )
+                outputs_TNH = jnp.einsum("TNA,ANH -> TNH", outputs_TNA,
+                                         self.kernel_v_up_proj_ANH.value)
+
+        with jax.named_scope("o_proj"):
+            outputs_TNH = nnx.with_sharding_constraint(
+                outputs_TNH, self.activation_attention_out_td)
+            outputs_TR = outputs_TNH.reshape(outputs_TNH.shape[0],
+                                             self.N * self.v_head_dim)
+            o_TD = jnp.einsum("TR,RD -> TD", outputs_TR,
+                              self.kernel_o_proj_RD.value)
+
+        return new_kv_cache, o_TD
+
+    def attention(
+        self,
+        is_prefill: bool,
+        kv_cache: KVCache,
+        q_TNH: jax.Array,
+        k_SKH: jax.Array,
+        v_SKH: jax.Array,
+        attention_metadata: AttentionMetadata,
+        mesh: Mesh,
+        q_scale: float | None = None,
+        k_scale: float | None = None,
+        v_scale: float | None = None,
+    ) -> Tuple[KVCache, jax.Array]:
+        """Performs scaled dot-product attention and updates the KV cache.
+
+        This function handles the core attention logic, which varies between
+        prefill and generation modes. In prefill, it computes self-attention
+        over the input sequence with a causal mask. In generation, it attends
+        to the full history of keys and values stored in the cache.
+
+        Args:
+            is_prefill: A boolean indicating if the mode is 'prefill'.
+            kv_cache: The key-value cache to be updated and used.
+            q_TNH: Query tensor of shape `(query_seq, num_attention_heads, head_dim)`.
+            k_SKH: Key tensor of shape `(kv_seq, num_key_value_heads, head_dim)`.
+            v_SKH: Value tensor of shape `(kv_seq, num_key_value_heads, head_dim)`.
+            attention_metadata: Metadata containing sequence lengths.
+            mesh: The JAX device mesh (unused in this specific function but
+                kept for potential future use or API consistency).
+            q_scale: Quantization scale for q.
+            k_scale: Quantization scale for k.
+            v_scale: Quantization scale for v.
+
+        Returns:
+            A tuple containing:
+            - The updated KV cache.
+            - The attention output tensor of shape
+              `(seq, num_q_heads, head_dim)`.
+        """
+        md = attention_metadata
+        in_specs = (
+            self.query_tnh,  # q
+            self.keyvalue_skh,  # k
+            self.keyvalue_skh,  # v
+            P(None, None, "model"),  # kv_cache
+            P(),  # md.seq_lens: Replicated
+            P(),  # page_indices_flat: Replicated
+            P(),  # query_start_loc: Replicated
+            P(),  # distribution: Replicated
+        )
+        out_specs = (self.attn_o_tnh, P(None, None, "model"))
+
+        def _ragged_paged_attention(*args):
+            outputs = ragged_paged_attention(
+                *args,
+                sm_scale=self.scale,
+                q_scale=q_scale,
+                k_scale=k_scale,
+                v_scale=v_scale,
+            )
+            return outputs
+
+        output_TNH, kv_cache = jax.jit(
+            jax.shard_map(
+                _ragged_paged_attention,
+                mesh=mesh,
+                in_specs=in_specs,
+                out_specs=out_specs,
+                check_vma=False,
+            ))(
+                q_TNH,
+                k_SKH,
+                v_SKH,
+                kv_cache,
+                md.seq_lens,
+                md.block_tables,
+                md.query_start_loc,
+                md.request_distribution,
+            )
+        return kv_cache, output_TNH
+
+    def mla_attention(
+        self,
+        kv_cache: KVCache,
+        q_TNA: jax.Array,
+        q_rope_TNH: jax.Array,
+        k_SA: jax.Array,
+        k_rope_SH: jax.Array,
+        attention_metadata: AttentionMetadata,
+        mesh: Mesh,
+    ) -> Tuple[KVCache, jax.Array]:
+        """Performs scaled dot-product attention and updates the KV cache.
+
+        This function handles the core attention logic, which varies between
+        prefill and generation modes. In prefill, it computes self-attention
+        over the input sequence with a causal mask. In generation, it attends
+        to the full history of keys and values stored in the cache.
+
+        Args:
+            kv_cache: The key-value cache to be updated and used.
+            q_TNA: Query tensor of shape `(query_seq, num_attention_heads, lkv_dim)`.
+            q_rope_TNH: Query rope tensor of shape `(query_seq, num_attention_heads, rope_dim)`.
+            k_SA: Key tensor of shape `(kv_seq, lkv_dim)`.
+            k_rope_SH: Key rope tensor of shape `(kv_seq, rope_dim)`.
+            attention_metadata: Metadata containing sequence lengths.
+            mesh: The JAX device mesh (unused in this specific function but
+                kept for potential future use or API consistency).
+            q_scale: Quantization scale for q.
+            k_scale: Quantization scale for k.
+            v_scale: Quantization scale for v.
+
+        Returns:
+            A tuple containing:
+            - The updated KV cache.
+            - The attention output tensor of shape
+              `(seq, num_q_heads, head_dim)`.
+        """
+        md = attention_metadata
+        in_specs = (
+            self.query_tnh,  # q
+            self.query_tnh,  # q_rope
+            self.keyvalue_skh,  # k
+            self.keyvalue_skh,  # k_rope
+            P(ShardingAxisName.MLP_TENSOR),  # kv_cache
+            P(ShardingAxisName.ATTN_DATA),  # md.seq_lens: Replicated
+            P(ShardingAxisName.ATTN_DATA),  # page_indices_flat: Replicated
+            P(ShardingAxisName.ATTN_DATA),  # query_start_loc: Replicated
+            P(ShardingAxisName.ATTN_DATA),  # distribution: Replicated
+        )
+
+        out_specs = (self.attn_o_tnh, P(ShardingAxisName.MLP_TENSOR))
+
+        def _mla_ragged_paged_attention(q, q_rope, k, k_rope, kv_cache, *args):
+
+            def _initialize_block_sizes():
+                # Set reasonable starting estimates for block sizes. (TODO(gpolovets): update this to use tuned sizes)
+                # Referring to get_tuned_block_sizes() in kernels/ragged_paged_attention/v3/tuned_block_sizes.py: 'TPU v7'/128/'q_bfloat16_kv_bfloat16/q_head-128_kv_head-1_head-128'/4096
+                max_num_tokens = q.shape[0]
+                max_num_seqs = md.seq_lens.shape[0]
+                num_page_indices = md.block_tables.shape[0]
+                assert num_page_indices % max_num_seqs == 0
+                pages_per_seq = num_page_indices // max_num_seqs
+                # num_kv_pages_per_block = min(pages_per_seq, 16)
+                bkv_p, bq_sz = get_tuned_block_sizes(
+                    q.dtype,
+                    kv_cache.dtype,
+                    self.num_attention_heads,
+                    1,
+                    self.qk_nope_head_dim,
+                    kv_cache.shape[1],  # page size
+                    max_num_tokens,
+                    pages_per_seq,
+                )
+                num_kv_pages_per_block = min(min(pages_per_seq, bkv_p), 4)
+                num_queries_per_block = min(min(max_num_tokens, bq_sz),
+                                            4)  # OOMS at 8
+                return num_kv_pages_per_block, num_queries_per_block
+
+            num_kv_pages_per_block, num_queries_per_block = _initialize_block_sizes(
+            )
+            output, kv_cache = mla_ragged_paged_attention(
+                q,
+                q_rope,
+                k,
+                k_rope,
+                kv_cache,
+                *args,
+                sm_scale=self.scale,
+                num_kv_pages_per_block=num_kv_pages_per_block,
+                num_queries_per_block=num_queries_per_block)
+
+            return kv_cache, output
+
+        kv_cache, output_TNH = jax.jit(
+            jax.shard_map(
+                _mla_ragged_paged_attention,
+                mesh=mesh,
+                in_specs=in_specs,
+                out_specs=out_specs,
+                check_vma=False,
+            ), )(
+                q_TNA,
+                q_rope_TNH,
+                k_SA,
+                k_rope_SH,
+                kv_cache,
+                md.seq_lens,
+                md.block_tables,
+                md.query_start_loc,
+                md.request_distribution,
+            )
+        return kv_cache, output_TNH
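
Worth noting about the `use_mla_kernel` path above: the `q_proj` scope folds `kernel_k_up_proj_ANH` into the query (producing `q_TNA`), and `kernel_v_up_proj_ANH` is only applied after `mla_attention` returns, so the kernel consumes the latent `kv_SA` plus the shared rope components rather than materialized per-head K/V. Below is a minimal sketch, not from the package and using tiny illustrative shapes and names, of why absorbing the k up-projection into the query yields the same attention logits as first expanding the latents into per-head keys:

    # Illustrative only: demonstrates the weight absorption used in the
    # use_mla_kernel branch above, with made-up shapes and variable names.
    import jax.numpy as jnp
    import numpy as np

    T, S, N, H, A = 3, 5, 2, 4, 6  # query tokens, kv tokens, heads, nope head dim, latent rank
    rng = np.random.default_rng(0)
    q_nope_TNH = jnp.asarray(rng.normal(size=(T, N, H)), jnp.float32)
    kv_latent_SA = jnp.asarray(rng.normal(size=(S, A)), jnp.float32)
    w_k_up_ANH = jnp.asarray(rng.normal(size=(A, N, H)), jnp.float32)

    # Reference path: materialize per-head keys from the latents, then score q . k.
    k_SNH = jnp.einsum("SA,ANH->SNH", kv_latent_SA, w_k_up_ANH)
    logits_ref = jnp.einsum("TNH,SNH->TNS", q_nope_TNH, k_SNH)

    # Absorbed path: fold the k up-projection into the query (as q_TNA above)
    # and score directly against the latent vectors.
    q_TNA = jnp.einsum("TNH,ANH->TNA", q_nope_TNH, w_k_up_ANH)
    logits_abs = jnp.einsum("TNA,SA->TNS", q_TNA, kv_latent_SA)

    np.testing.assert_allclose(np.asarray(logits_ref), np.asarray(logits_abs),
                               rtol=1e-4, atol=1e-4)

This equivalence is what lets the kernel path pass the latent `kv_SA` and the shared `k_rope_SH` to `mla_ragged_paged_attention` instead of full per-head key and value tensors.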