tpu_inference-0.12.0.dev20251222-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/layers/jax/attention/attention.py
@@ -0,0 +1,268 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import InitVar, dataclass
+from typing import Any, Tuple
+
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from flax.typing import Sharding
+from jax.sharding import Mesh
+from jax.sharding import PartitionSpec as P
+
+from tpu_inference import utils
+from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
+    ragged_paged_attention
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.jax.base import create_param
+from tpu_inference.layers.jax.rope_interface import apply_rope
+
+KVCache = Tuple[jax.Array, jax.Array]
+
+
+@dataclass(kw_only=True)
+class Attention(nnx.Module):
+    """An implementation of attention.
+
+    This module performs the attention mechanism for a transformer model,
+    including query, key, and value projections, application of Rotary
+    Position Embeddings (RoPE), and management of a KV cache for efficient
+    autoregressive generation. It supports both prefill and generation
+    (decode) modes and handles tensor sharding for distributed computation.
+
+    Attributes:
+        mesh: The JAX device mesh for distributed computation.
+    """
+    hidden_size: int
+    num_attention_heads: int
+    num_key_value_heads: int
+    head_dim: int
+    rope_theta: float
+    rope_scaling: dict[str, Any]
+    dtype: jnp.dtype
+    mesh: Mesh
+    kv_cache_dtype: str
+
+    dnh_sharding: Sharding = ()
+    dkh_sharding: Sharding = ()
+    nhd_sharding: Sharding = ()
+
+    activation_q_td: Sharding = (ShardingAxisName.ATTN_DATA)
+    query_tnh: P = P(ShardingAxisName.ATTN_DATA)
+    keyvalue_skh: P = P(ShardingAxisName.ATTN_DATA)
+
+    attn_o_tnh: P = P(ShardingAxisName.ATTN_DATA)
+    rngs: InitVar[nnx.Rngs]
+
+    random_init: bool = False
+    attention_chunk_size: int | None = None
+    rope_input_ordering: str = "split"
+
+    _q_scale: float = 1.0
+    _k_scale: float = 1.0
+    _v_scale: float = 1.0
+
+    kv_cache_quantized_dtype = None
+
+    def __post_init__(self, rngs: nnx.Rngs):
+        """Initializes the weight kernels for Q, K, V, and O projections."""
+        N = self.num_attention_heads
+        K = self.num_key_value_heads
+        D = self.hidden_size
+        H = self.head_dim
+
+        self.kernel_q_proj_DNH = create_param(rngs, (D, N, H),
+                                              self.dnh_sharding,
+                                              self.dtype,
+                                              random_init=self.random_init)
+        self.kernel_k_proj_DKH = create_param(rngs, (D, K, H),
+                                              self.dkh_sharding,
+                                              self.dtype,
+                                              random_init=self.random_init)
+        self.kernel_v_proj_DKH = create_param(rngs, (D, K, H),
+                                              self.dkh_sharding,
+                                              self.dtype,
+                                              random_init=self.random_init)
+        self.kernel_o_proj_NHD = create_param(rngs, (N, H, D),
+                                              self.nhd_sharding,
+                                              self.dtype,
+                                              random_init=self.random_init)
+
+        if self.kv_cache_dtype != "auto":
+            self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
+                self.kv_cache_dtype)
+
+    def __call__(self,
+                 x,
+                 is_prefill,
+                 kv_cache: KVCache,
+                 attention_metadata: AttentionMetadata,
+                 use_attention_rope: bool = True):
+        """Performs the forward pass of the attention module.
+
+        This method computes the attention output by projecting the input `x`
+        to queries, keys, and values, applying RoPE, performing scaled
+        dot-product attention, and projecting the result back to the model
+        dimension. It updates and utilizes a KV cache.
+
+        Args:
+            x: The input tensor of shape `(seq_len, d_model)`.
+            is_prefill: Whether the operation mode is prefill (otherwise it is generate).
+            kv_cache: The key-value cache for storing past attention states.
+            attention_metadata: Metadata for attention, such as input positions.
+            use_attention_rope: Whether to use RoPE.
+
+        Returns:
+            A tuple containing:
+            - The updated KV cache.
+            - The attention output tensor of shape
+              `(batch_size, seq_len, d_model)`.
+        """
+        md = attention_metadata
+        x_SD = jnp.asarray(x, self.dtype)
+        x_q_TD = nnx.with_sharding_constraint(x, self.activation_q_td)
+        H = self.head_dim
+        with jax.named_scope("q_proj"):
+            q_TNH = jnp.einsum('TD,DNH -> TNH', x_q_TD,
+                               self.kernel_q_proj_DNH.value)
+            if use_attention_rope:
+                q_TNH = apply_rope(q_TNH, md.input_positions, H,
+                                   self.rope_theta, self.rope_scaling,
+                                   self.rope_input_ordering)
+            q_TNH = nnx.with_sharding_constraint(q_TNH, self.query_tnh)
+        with jax.named_scope("k_proj"):
+            k_SKH = jnp.einsum('SD,DKH -> SKH', x_SD,
+                               self.kernel_k_proj_DKH.value)
+            if use_attention_rope:
+                k_SKH = apply_rope(k_SKH, md.input_positions, H,
+                                   self.rope_theta, self.rope_scaling,
+                                   self.rope_input_ordering)
+            k_SKH = nnx.with_sharding_constraint(k_SKH, self.keyvalue_skh)
+
+        with jax.named_scope("v_proj"):
+            v_SKH = jnp.einsum('SD,DKH -> SKH', x_SD,
+                               self.kernel_v_proj_DKH.value)
+
+        q_scale = k_scale = v_scale = None
+        if self.kv_cache_quantized_dtype:
+            # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
+            # q_scale = self._q_scale
+            k_scale = self._k_scale
+            v_scale = self._v_scale
+            k_SKH, v_SKH = quantize_kv(self.kv_cache_quantized_dtype, k_SKH,
+                                       v_SKH, k_scale, v_scale)
+
+        with jax.named_scope("attn_op"):
+            new_kv_cache, outputs_TNH = self.attention(
+                is_prefill,
+                kv_cache,
+                q_TNH,
+                k_SKH,
+                v_SKH,
+                attention_metadata,
+                self.mesh,
+                q_scale=q_scale,
+                k_scale=k_scale,
+                v_scale=v_scale,
+            )
+
+        with jax.named_scope("o_proj"):
+            o_TD = jnp.einsum('TNH,NHD -> TD', outputs_TNH,
+                              self.kernel_o_proj_NHD.value)
+        return new_kv_cache, o_TD
+
+    def attention(
+        self,
+        is_prefill: bool,
+        kv_cache: KVCache,
+        q_TNH: jax.Array,
+        k_SKH: jax.Array,
+        v_SKH: jax.Array,
+        attention_metadata: AttentionMetadata,
+        mesh: Mesh,
+        q_scale: float | None = None,
+        k_scale: float | None = None,
+        v_scale: float | None = None,
+    ) -> Tuple[KVCache, jax.Array]:
+        """Performs scaled dot-product attention and updates the KV cache.
+
+        This function handles the core attention logic, which varies between
+        prefill and generation modes. In prefill, it computes self-attention
+        over the input sequence with a causal mask. In generation, it attends
+        to the full history of keys and values stored in the cache.
+
+        Args:
+            is_prefill: A boolean indicating if the mode is 'prefill'.
+            kv_cache: The key-value cache to be updated and used.
+            q_TNH: Query tensor of shape `(query_seq, num_attention_heads, head_dim)`.
+            k_SKH: Key tensor of shape `(kv_seq, num_key_value_heads, head_dim)`.
+            v_SKH: Value tensor of shape `(kv_seq, num_key_value_heads, head_dim)`.
+            attention_metadata: Metadata containing sequence lengths.
+            mesh: The JAX device mesh (unused in this specific function but
+                kept for potential future use or API consistency).
+            q_scale: Quantization scale for q.
+            k_scale: Quantization scale for k.
+            v_scale: Quantization scale for v.
+
+        Returns:
+            A tuple containing:
+            - The updated KV cache.
+            - The attention output tensor of shape
+              `(seq, num_q_heads, head_dim)`.
+        """
+        md = attention_metadata
+        kv_cache_spec = P(ShardingAxisName.ATTN_DATA, None, "model")
+        in_specs = (
+            self.query_tnh,  # q
+            self.keyvalue_skh,  # k
+            self.keyvalue_skh,  # v
+            kv_cache_spec,  # kv_cache
+            P(ShardingAxisName.ATTN_DATA),  # md.seq_lens
+            P(ShardingAxisName.ATTN_DATA),  # page_indices_flat
+            P(ShardingAxisName.ATTN_DATA),  # query_start_loc
+            P(ShardingAxisName.ATTN_DATA),  # distribution
+        )
+
+        out_specs = (self.attn_o_tnh, kv_cache_spec)
+
+        def _ragged_paged_attention(*args):
+            return ragged_paged_attention(
+                *args,
+                sm_scale=q_TNH.shape[-1]**-0.5,
+                q_scale=q_scale,
+                k_scale=k_scale,
+                v_scale=v_scale,
+            )
+
+        output_TNH, kv_cache = jax.jit(
+            jax.shard_map(
+                _ragged_paged_attention,
+                mesh=mesh,
+                in_specs=in_specs,
+                out_specs=out_specs,
+                check_vma=False,
+            ))(
+                q_TNH,
+                k_SKH,
+                v_SKH,
+                kv_cache,
+                md.seq_lens,
+                md.block_tables,
+                md.query_start_loc,
+                md.request_distribution,
+            )
+        return kv_cache, output_TNH
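
The module above names arrays by shape suffix: T is query tokens, S is KV tokens, D is the model dimension, N the number of attention heads, K the number of KV heads, and H the head dimension, so each einsum spells out its operand shapes. A minimal sketch of those projection shapes, using made-up toy dimensions rather than anything from the package:

import jax
import jax.numpy as jnp

# Toy sizes for illustration only: tokens, d_model, q heads, kv heads, head_dim.
T, D, N, K, H = 8, 32, 4, 2, 16

key = jax.random.PRNGKey(0)
x_TD = jax.random.normal(key, (T, D))
w_q_DNH = jax.random.normal(key, (D, N, H))
w_k_DKH = jax.random.normal(key, (D, K, H))
w_o_NHD = jax.random.normal(key, (N, H, D))

q_TNH = jnp.einsum('TD,DNH -> TNH', x_TD, w_q_DNH)  # (8, 4, 16)
k_TKH = jnp.einsum('TD,DKH -> TKH', x_TD, w_k_DKH)  # (8, 2, 16)
o_TD = jnp.einsum('TNH,NHD -> TD', q_TNH, w_o_NHD)  # (8, 32)
assert q_TNH.shape == (T, N, H) and o_TD.shape == (T, D)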
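When kv_cache_dtype is not "auto", keys and values pass through quantize_kv before reaching the kernel; only k_scale and v_scale are applied, with q_scale left disabled pending the VREG spill fix noted in the TODO. The implementation of quantize_kv is not part of this diff; the toy helper below only illustrates the usual scale-then-cast idea and is not the package's code:

import jax.numpy as jnp

def toy_quantize_kv(quant_dtype, k, v, k_scale, v_scale):
    # Illustration only: divide by a per-tensor scale, then cast to the
    # narrower KV-cache dtype. The real quantize_kv may differ.
    return (k / k_scale).astype(quant_dtype), (v / v_scale).astype(quant_dtype)

k = jnp.ones((8, 2, 16), jnp.bfloat16)
v = jnp.ones((8, 2, 16), jnp.bfloat16)
k_q, v_q = toy_quantize_kv(jnp.float8_e4m3fn, k, v, k_scale=1.0, v_scale=1.0)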
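The attention() method shards the ragged paged attention kernel with jax.shard_map: every argument carries a PartitionSpec over the ATTN_DATA axis, so each data-parallel shard runs the Pallas kernel on its local queries, KV pages, and metadata, while the KV cache is additionally split over the "model" axis. A self-contained sketch of the same jit-plus-shard_map pattern on a one-axis mesh (assumes a recent JAX exposing jax.shard_map and jax.make_mesh; the function and specs are stand-ins, not the kernel):

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

# One-axis mesh over all available devices.
mesh = jax.make_mesh((jax.device_count(),), ('data',))

def local_fn(q, k):
    # Runs once per shard, on the local slice of each input.
    return jnp.einsum('th,sh->ts', q, k)

sharded = jax.jit(
    jax.shard_map(
        local_fn,
        mesh=mesh,
        in_specs=(P('data'), P('data')),
        out_specs=P('data'),
    ))

q = jnp.ones((8 * jax.device_count(), 16))
k = jnp.ones((4 * jax.device_count(), 16))
out = sharded(q, k)  # per-shard (8, 4) blocks, concatenated to (8 * n, 4)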