tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/layers/jax/attention/attention.py
@@ -0,0 +1,255 @@
+ from dataclasses import InitVar, dataclass
+ from typing import Any, Tuple
+
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from flax.typing import Sharding
+ from jax.experimental import shard_map
+ from jax.sharding import Mesh
+ from jax.sharding import PartitionSpec as P
+
+ from tpu_inference import utils
+ from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
+     ragged_paged_attention
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.common.sharding import ShardingAxisName
+ from tpu_inference.layers.jax.base import create_param
+ from tpu_inference.layers.jax.rope_interface import apply_rope
+
+ KVCache = Tuple[jax.Array, jax.Array]
+
+
+ @dataclass(kw_only=True)
+ class Attention(nnx.Module):
+     """An implementation of attention.
+
+     This module performs the attention mechanism for a transformer model,
+     including query, key, and value projections, application of Rotary
+     Position Embeddings (RoPE), and management of a KV cache for efficient
+     autoregressive generation. It supports both prefill and generation
+     (decode) modes and handles tensor sharding for distributed computation.
+
+     Attributes:
+         mesh: The JAX device mesh for distributed computation.
+     """
+     hidden_size: int
+     num_attention_heads: int
+     num_key_value_heads: int
+     head_dim: int
+     rope_theta: float
+     rope_scaling: dict[str, Any]
+     dtype: jnp.dtype
+     mesh: Mesh
+     kv_cache_dtype: str
+
+     dnh_sharding: Sharding = ()
+     dkh_sharding: Sharding = ()
+     nhd_sharding: Sharding = ()
+
+     activation_q_td: Sharding = (ShardingAxisName.ATTN_DATA)
+     query_tnh: P = P(ShardingAxisName.ATTN_DATA)
+     keyvalue_skh: P = P(ShardingAxisName.ATTN_DATA)
+
+     attn_o_tnh: P = P(ShardingAxisName.ATTN_DATA)
+     rngs: InitVar[nnx.Rngs]
+
+     random_init: bool = False
+     attention_chunk_size: int | None = None
+     rope_input_ordering: str = "split"
+
+     _q_scale: float = 1.0
+     _k_scale: float = 1.0
+     _v_scale: float = 1.0
+
+     kv_cache_quantized_dtype = None
+
+     def __post_init__(self, rngs: nnx.Rngs):
+         """Initializes the weight kernels for Q, K, V, and O projections."""
+         N = self.num_attention_heads
+         K = self.num_key_value_heads
+         D = self.hidden_size
+         H = self.head_dim
+
+         self.kernel_q_proj_DNH = create_param(rngs, (D, N, H),
+                                               self.dnh_sharding,
+                                               self.dtype,
+                                               random_init=self.random_init)
+         self.kernel_k_proj_DKH = create_param(rngs, (D, K, H),
+                                               self.dkh_sharding,
+                                               self.dtype,
+                                               random_init=self.random_init)
+         self.kernel_v_proj_DKH = create_param(rngs, (D, K, H),
+                                               self.dkh_sharding,
+                                               self.dtype,
+                                               random_init=self.random_init)
+         self.kernel_o_proj_NHD = create_param(rngs, (N, H, D),
+                                               self.nhd_sharding,
+                                               self.dtype,
+                                               random_init=self.random_init)
+
+         if self.kv_cache_dtype != "auto":
+             self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
+                 self.kv_cache_dtype)
+
+     def __call__(self,
+                  x,
+                  is_prefill,
+                  kv_cache: KVCache,
+                  attention_metadata: AttentionMetadata,
+                  use_attention_rope: bool = True):
+         """Performs the forward pass of the attention module.
+
+         This method computes the attention output by projecting the input `x`
+         to queries, keys, and values, applying RoPE, performing scaled
+         dot-product attention, and projecting the result back to the model
+         dimension. It updates and utilizes a KV cache.
+
+         Args:
+             x: The input tensor of shape `(seq_len, d_model)`.
+             is_prefill: Whether the operation mode is prefill (otherwise it is generate).
+             kv_cache: The key-value cache for storing past attention states.
+             attention_metadata: Metadata for attention, such as input positions.
+             use_attention_rope: Whether to use RoPE.
+
+         Returns:
+             A tuple containing:
+                 - The updated KV cache.
+                 - The attention output tensor of shape
+                   `(batch_size, seq_len, d_model)`.
+         """
+         md = attention_metadata
+         x_SD = jnp.asarray(x, self.dtype)
+         x_q_TD = nnx.with_sharding_constraint(x, self.activation_q_td)
+         H = self.head_dim
+         with jax.named_scope("q_proj"):
+             q_TNH = jnp.einsum('TD,DNH -> TNH', x_q_TD,
+                                self.kernel_q_proj_DNH.value)
+             if use_attention_rope:
+                 q_TNH = apply_rope(q_TNH, md.input_positions, H,
+                                    self.rope_theta, self.rope_scaling,
+                                    self.rope_input_ordering)
+             q_TNH = nnx.with_sharding_constraint(q_TNH, self.query_tnh)
+         with jax.named_scope("k_proj"):
+             k_SKH = jnp.einsum('SD,DKH -> SKH', x_SD,
+                                self.kernel_k_proj_DKH.value)
+             if use_attention_rope:
+                 k_SKH = apply_rope(k_SKH, md.input_positions, H,
+                                    self.rope_theta, self.rope_scaling,
+                                    self.rope_input_ordering)
+             k_SKH = nnx.with_sharding_constraint(k_SKH, self.keyvalue_skh)
+
+         with jax.named_scope("v_proj"):
+             v_SKH = jnp.einsum('SD,DKH -> SKH', x_SD,
+                                self.kernel_v_proj_DKH.value)
+
+         q_scale = k_scale = v_scale = None
+         if self.kv_cache_quantized_dtype:
+             # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
+             # q_scale = self._q_scale
+             k_scale = self._k_scale
+             v_scale = self._v_scale
+             k_SKH, v_SKH = utils.quantize_kv(k_SKH, v_SKH,
+                                              self.kv_cache_quantized_dtype,
+                                              k_scale, v_scale)
+
+         with jax.named_scope("attn_op"):
+             new_kv_cache, outputs_TNH = self.attention(
+                 is_prefill,
+                 kv_cache,
+                 q_TNH,
+                 k_SKH,
+                 v_SKH,
+                 attention_metadata,
+                 self.mesh,
+                 q_scale=q_scale,
+                 k_scale=k_scale,
+                 v_scale=v_scale,
+             )
+
+         with jax.named_scope("o_proj"):
+             o_TD = jnp.einsum('TNH,NHD -> TD', outputs_TNH,
+                               self.kernel_o_proj_NHD.value)
+         return new_kv_cache, o_TD
+
+     def attention(
+         self,
+         is_prefill: bool,
+         kv_cache: KVCache,
+         q_TNH: jax.Array,
+         k_SKH: jax.Array,
+         v_SKH: jax.Array,
+         attention_metadata: AttentionMetadata,
+         mesh: Mesh,
+         q_scale: float | None = None,
+         k_scale: float | None = None,
+         v_scale: float | None = None,
+     ) -> Tuple[KVCache, jax.Array]:
+         """Performs scaled dot-product attention and updates the KV cache.
+
+         This function handles the core attention logic, which varies between
+         prefill and generation modes. In prefill, it computes self-attention
+         over the input sequence with a causal mask. In generation, it attends
+         to the full history of keys and values stored in the cache.
+
+         Args:
+             is_prefill: A boolean indicating if the mode is 'prefill'.
+             kv_cache: The key-value cache to be updated and used.
+             q_TNH: Query tensor of shape `(query_seq, num_attention_heads, head_dim)`.
+             k_SKH: Key tensor of shape `(kv_seq, num_key_value_heads, head_dim)`.
+             v_SKH: Value tensor of shape `(kv_seq, num_key_value_heads, head_dim)`.
+             attention_metadata: Metadata containing sequence lengths.
+             mesh: The JAX device mesh (unused in this specific function but
+                 kept for potential future use or API consistency).
+             q_scale: Quantization scale for q.
+             k_scale: Quantization scale for k.
+             v_scale: Quantization scale for v.
+
+         Returns:
+             A tuple containing:
+                 - The updated KV cache.
+                 - The attention output tensor of shape
+                   `(seq, num_q_heads, head_dim)`.
+         """
+         md = attention_metadata
+         kv_cache_spec = P(ShardingAxisName.ATTN_DATA, None, "model")
+         in_specs = (
+             self.query_tnh,  # q
+             self.keyvalue_skh,  # k
+             self.keyvalue_skh,  # v
+             kv_cache_spec,  # kv_cache
+             P(ShardingAxisName.ATTN_DATA),  # md.seq_lens
+             P(ShardingAxisName.ATTN_DATA),  # page_indices_flat
+             P(ShardingAxisName.ATTN_DATA),  # query_start_loc
+             P(ShardingAxisName.ATTN_DATA),  # distribution
+         )
+
+         out_specs = (self.attn_o_tnh, kv_cache_spec)
+
+         def _ragged_paged_attention(*args):
+             return ragged_paged_attention(
+                 *args,
+                 sm_scale=q_TNH.shape[-1]**-0.5,
+                 q_scale=q_scale,
+                 k_scale=k_scale,
+                 v_scale=v_scale,
+             )
+
+         output_TNH, kv_cache = jax.jit(
+             shard_map.shard_map(
+                 _ragged_paged_attention,
+                 mesh=mesh,
+                 in_specs=in_specs,
+                 out_specs=out_specs,
+                 check_rep=False,
+             ))(
+                 q_TNH,
+                 k_SKH,
+                 v_SKH,
+                 kv_cache,
+                 md.seq_lens,
+                 md.block_tables,
+                 md.query_start_loc,
+                 md.request_distribution,
+             )
+         return kv_cache, output_TNH
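
The Attention module above names axes with single letters in its einsums: T for query tokens, S for KV tokens, D for the hidden size, N for query heads, K for KV heads, and H for the head dimension. The short JAX sketch below checks those shape conventions with toy sizes; the dimensions are illustrative only and nothing here depends on tpu_inference.

    # Minimal shape sketch of the Q/K/O einsums used by Attention above.
    # Toy sizes, assumed purely for illustration.
    import jax
    import jax.numpy as jnp

    T, D, N, K, H = 8, 64, 4, 2, 16
    key = jax.random.PRNGKey(0)
    x_TD = jax.random.normal(key, (T, D))
    w_q_DNH = jax.random.normal(key, (D, N, H))
    w_k_DKH = jax.random.normal(key, (D, K, H))
    w_o_NHD = jax.random.normal(key, (N, H, D))

    q_TNH = jnp.einsum('TD,DNH -> TNH', x_TD, w_q_DNH)  # query projection
    k_SKH = jnp.einsum('SD,DKH -> SKH', x_TD, w_k_DKH)  # key projection (S == T here)
    o_TD = jnp.einsum('TNH,NHD -> TD', q_TNH, w_o_NHD)  # output projection
    assert q_TNH.shape == (T, N, H)
    assert k_SKH.shape == (T, K, H)
    assert o_TD.shape == (T, D)

With K smaller than N this is a grouped-query layout; in the module the same contractions run against the kernel_*_proj parameters created in __post_init__, and the grouped heads are consumed by ragged_paged_attention.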
tpu_inference/layers/jax/attention/deepseek_v3_attention.py
@@ -0,0 +1,354 @@
+ import math
+ from dataclasses import InitVar, dataclass
+ from typing import Any, Tuple
+
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from flax.typing import Sharding
+ from jax.experimental import shard_map
+ from jax.sharding import Mesh
+ from jax.sharding import PartitionSpec as P
+
+ from tpu_inference import utils
+ from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
+     ragged_paged_attention
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.jax.base import create_param
+ from tpu_inference.layers.jax.layers import RMSNorm
+ from tpu_inference.layers.jax.rope import DeepseekScalingRotaryEmbedding
+
+ KVCache = Tuple[jax.Array, jax.Array]
+
+
+ # TODO (wenxindongwork): Add MLA KV cache implementation. For now, cache complete KV vectors.
+ @dataclass(kw_only=True)
+ class MLA(nnx.Module):
+     """An implementation of Multi-Head Latent Attention as
+     described in the DeepSeek V3 paper.
+
+     Attributes:
+         mesh: The JAX device mesh for distributed computation.
+     """
+     hidden_size: int
+     num_attention_heads: int
+     num_key_value_heads: int
+     head_dim: int
+     rope_theta: float
+     rope_scaling: dict[str, Any]
+     dtype: jnp.dtype
+     kv_cache_dtype: str
+     mesh: Mesh
+
+     q_lora_rank: int
+     kv_lora_rank: int
+     qk_nope_head_dim: int
+     qk_rope_head_dim: int
+     v_head_dim: int
+     rms_norm_eps: float
+
+     # Sharding attributes
+     nhd_sharding: Sharding = ()
+     q_da_sharding: Sharding = ()
+     anh_sharding: Sharding = ()
+     kv_da_sharding: Sharding = ()
+
+     activation_attention_td: Sharding = ()
+     activation_q_td: Sharding = ()
+     query_tnh: P = P()
+     keyvalue_skh: P = P()
+
+     attn_o_tnh: P = P()
+     activation_attention_out_td: Sharding = ()
+
+     random_init: bool = False
+     attention_chunk_size: int | None = None
+     rope_input_ordering: str = "split"
+     quant: Any | None = None
+     rope_mscale_all_dim: float = 1.0
+
+     rngs: InitVar[nnx.Rngs]
+
+     _q_scale: float = 1
+     _k_scale: float = 1
+     _v_scale: float = 1
+
+     def __post_init__(self, rngs: nnx.Rngs):
+         self.N = self.num_attention_heads
+         self.K = self.num_key_value_heads
+         self.D = self.hidden_size
+
+         self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
+
+         assert self.N == self.K, "N and K must be equal for MLA"
+
+         if self.rope_scaling["factor"] <= 1.0:
+             yarn_mscale = 1.0
+         else:
+             yarn_mscale = 0.1 * self.rope_mscale_all_dim * math.log(
+                 self.rope_scaling["factor"]) + 1.0
+         self.scale = self.qk_head_dim**-0.5 * yarn_mscale**2
+
+         self.rope = DeepseekScalingRotaryEmbedding(
+             rotary_dim=self.qk_rope_head_dim,
+             rope_theta=self.rope_theta,
+             original_max_position_embeddings=self.
+             rope_scaling["original_max_position_embeddings"],
+             scaling_factor=self.rope_scaling["factor"],
+             dtype=self.dtype,
+             beta_fast=self.rope_scaling["beta_fast"],
+             beta_slow=self.rope_scaling["beta_slow"],
+             mscale_value=self.rope_scaling["mscale"],
+             mscale_all_dim=self.rope_scaling["mscale_all_dim"],
+         )
+
+         # Initializes the weight kernels
+         self.kernel_q_down_proj_DA = create_param(rngs,
+                                                   (self.D, self.q_lora_rank),
+                                                   self.q_da_sharding,
+                                                   self.dtype,
+                                                   random_init=self.random_init)
+         self.kernel_q_up_proj_ANH = create_param(
+             rngs,
+             (self.q_lora_rank, self.N, self.qk_head_dim),
+             self.anh_sharding,
+             self.dtype,
+             random_init=self.random_init,
+         )
+         self.kernel_kv_down_proj_DA = create_param(
+             rngs,
+             (self.D, self.kv_lora_rank + self.qk_rope_head_dim),
+             self.kv_da_sharding,
+             self.dtype,
+             random_init=self.random_init,
+         )
+         self.kernel_kv_up_proj_ANH = create_param(
+             rngs,
+             (self.kv_lora_rank, self.N,
+              self.qk_nope_head_dim + self.v_head_dim),
+             self.anh_sharding,
+             self.dtype,
+             random_init=self.random_init,
+         )
+         self.kernel_o_proj_NHD = create_param(
+             rngs, (self.N, self.v_head_dim, self.D),
+             self.nhd_sharding,
+             self.dtype,
+             random_init=self.random_init)
+         self.q_rms_norm = RMSNorm(
+             dims=self.q_lora_rank,
+             epsilon=self.rms_norm_eps,
+             with_scale=True,
+             dtype=self.dtype,
+             random_init=self.random_init,
+             rngs=rngs,
+         )
+
+         self.kv_rms_norm = RMSNorm(
+             dims=self.kv_lora_rank,
+             random_init=self.random_init,
+             epsilon=self.rms_norm_eps,
+             with_scale=True,
+             dtype=self.dtype,
+             rngs=rngs,
+         )
+
+         self.kv_cache_quantized_dtype = None
+         if self.kv_cache_dtype != "auto":
+             self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
+                 self.kv_cache_dtype)
+
+     def __call__(self,
+                  x,
+                  is_prefill,
+                  kv_cache: KVCache,
+                  attention_metadata: AttentionMetadata,
+                  use_attention_rope: bool = True):
+         """Performs the forward pass of the attention module.
+
+         Args:
+             x: The input tensor of shape `(batch_size, seq_len, d_model)`.
+             is_prefill: Whether the operation mode is prefill (otherwise it is generate).
+             kv_cache: The key-value cache for storing past attention states.
+             attention_metadata: Metadata for attention, such as input positions.
+
+         Returns:
+             A tuple containing:
+                 - The updated KV cache.
+                 - The attention output tensor of shape
+                   `(batch_size, seq_len, d_model)`.
+         """
+         md = attention_metadata
+         x = jnp.asarray(x, self.dtype)
+         x_SD = nnx.with_sharding_constraint(x, self.activation_attention_td)
+         x_q_TD = nnx.with_sharding_constraint(x, self.activation_q_td)
+
+         with jax.named_scope("q_proj"):
+             # Query down projection.
+             q_TA = jnp.einsum("TD,DA -> TA", x_q_TD,
+                               self.kernel_q_down_proj_DA.value)
+             q_TA = self.q_rms_norm(q_TA)
+             # Query up projection.
+             q_TNH = jnp.einsum("TA,ANH -> TNH", q_TA,
+                                self.kernel_q_up_proj_ANH.value)
+             # Split the query into nope and rope.
+             q_nope_TNH = q_TNH[..., :self.qk_nope_head_dim]
+             q_rope_TNH = q_TNH[..., self.qk_nope_head_dim:]
+             q_rope_TNH = self.rope.apply_rope(md.input_positions, q_rope_TNH)
+             # Concatenate the nope and rope queries.
+             q_TNH = jnp.concatenate([q_nope_TNH, q_rope_TNH], axis=-1)
+             # Multiple the query by scaling factor
+             q_TNH = nnx.with_sharding_constraint(q_TNH, self.query_tnh)
+
+         with jax.named_scope("kv_proj"):
+             # KV down projection.
+             kv_SA = jnp.einsum("SD,DA -> SA", x_SD,
+                                self.kernel_kv_down_proj_DA.value)
+             # Split the key and value into latent kv vector and k rope vector.
+             k_rope_SH = kv_SA[..., self.kv_lora_rank:]
+             # Reshape k_rope_BSH to include head dimension for RoPE application
+             k_rope_SNH = k_rope_SH[..., None, :]
+             k_rope_SNH = self.rope.apply_rope(md.input_positions, k_rope_SNH)
+             k_rope_SNH = jnp.broadcast_to(
+                 k_rope_SNH,
+                 (k_rope_SNH.shape[0], self.N, self.qk_rope_head_dim))
+             kv_SA = kv_SA[..., :self.kv_lora_rank]
+             kv_SA = self.kv_rms_norm(kv_SA)
+             # KV up projection.
+             kv_nope_SNH = jnp.einsum("SA,ANH -> SNH", kv_SA,
+                                      self.kernel_kv_up_proj_ANH.value)
+             # Split the latent kv vector into k nope vector and v vector.
+             k_nope_SNH = kv_nope_SNH[..., :self.qk_nope_head_dim]
+             v_SNH = kv_nope_SNH[..., self.qk_nope_head_dim:]
+             # Concatenate the key vector.
+             k_SNH = jnp.concatenate([k_nope_SNH, k_rope_SNH], axis=-1)
+             k_SNH = nnx.with_sharding_constraint(k_SNH, self.keyvalue_skh)
+             v_SNH = nnx.with_sharding_constraint(v_SNH, self.keyvalue_skh)
+
+         with jax.named_scope("attn_op"):
+             # TODO(wenxindongwork): K and V have different head dimension,
+             # which is not supported by the current kv cache implementation.
+             # For now we are padding the v dimension to match the k dimension.
+             # Furthermore, deepseekv3 k head dimension is 192, which is
+             # not supported by the current attention kernel, which expects
+             # q, k, v head dimension to be multiple of 128. For now, we will
+             # pad the q, k, v dimension to multiple of 128.
+             # We should update the MLA kv cache implementation in the future.
+             multiple_of_128 = ((self.qk_head_dim - 1) // 128 + 1) * 128
+             q_TNH = jnp.pad(q_TNH, ((0, 0), (0, 0),
+                                     (0, multiple_of_128 - self.qk_head_dim)))
+             k_SNH = jnp.pad(k_SNH, ((0, 0), (0, 0),
+                                     (0, multiple_of_128 - self.qk_head_dim)))
+             v_SNH = jnp.pad(v_SNH, ((0, 0), (0, 0),
+                                     (0, multiple_of_128 - self.v_head_dim)))
+             q_scale = k_scale = v_scale = None
+             if self.kv_cache_quantized_dtype:
+                 # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
+                 # q_scale = self._q_scale
+                 k_scale = self._k_scale
+                 v_scale = self._v_scale
+                 k_SNH, v_SNH = utils.quantize_kv(k_SNH, v_SNH,
+                                                  self.kv_cache_quantized_dtype,
+                                                  k_scale, v_scale)
+             new_kv_cache, outputs_TNH = self.attention(
+                 is_prefill,
+                 kv_cache,
+                 q_TNH,
+                 k_SNH,
+                 v_SNH,
+                 attention_metadata,
+                 self.mesh,
+                 q_scale,
+                 k_scale,
+                 v_scale,
+             )
+             # TODO(wenxindongwork): For now, unpad the outputs_TNH to match the v_head_dim.
+             # We shall add the MLA kv cache implementation in the future.
+             outputs_TNH = outputs_TNH[..., :self.v_head_dim]
+
+         with jax.named_scope("o_proj"):
+             o_TD = jnp.einsum("TNH,NHD -> TD", outputs_TNH,
+                               self.kernel_o_proj_NHD.value)
+             o_TD = nnx.with_sharding_constraint(
+                 o_TD, self.activation_attention_out_td)
+         return new_kv_cache, o_TD
+
+     def attention(
+         self,
+         is_prefill: bool,
+         kv_cache: KVCache,
+         q_TNH: jax.Array,
+         k_SKH: jax.Array,
+         v_SKH: jax.Array,
+         attention_metadata: AttentionMetadata,
+         mesh: Mesh,
+         q_scale: float | None = None,
+         k_scale: float | None = None,
+         v_scale: float | None = None,
+     ) -> Tuple[KVCache, jax.Array]:
+         """Performs scaled dot-product attention and updates the KV cache.
+
+         This function handles the core attention logic, which varies between
+         prefill and generation modes. In prefill, it computes self-attention
+         over the input sequence with a causal mask. In generation, it attends
+         to the full history of keys and values stored in the cache.
+
+         Args:
+             is_prefill: A boolean indicating if the mode is 'prefill'.
+             kv_cache: The key-value cache to be updated and used.
+             q_TNH: Query tensor of shape `(query_seq, num_attention_heads, head_dim)`.
+             k_SKH: Key tensor of shape `(kv_seq, num_key_value_heads, head_dim)`.
+             v_SKH: Value tensor of shape `(kv_seq, num_key_value_heads, head_dim)`.
+             attention_metadata: Metadata containing sequence lengths.
+             mesh: The JAX device mesh (unused in this specific function but
+                 kept for potential future use or API consistency).
+             q_scale: Quantization scale for q.
+             k_scale: Quantization scale for k.
+             v_scale: Quantization scale for v.
+
+         Returns:
+             A tuple containing:
+                 - The updated KV cache.
+                 - The attention output tensor of shape
+                   `(seq, num_q_heads, head_dim)`.
+         """
+         md = attention_metadata
+         in_specs = (
+             self.query_tnh,  # q
+             self.keyvalue_skh,  # k
+             self.keyvalue_skh,  # v
+             P(None, None, "model"),  # kv_cache
+             P(),  # md.seq_lens: Replicated
+             P(),  # page_indices_flat: Replicated
+             P(),  # query_start_loc: Replicated
+             P(),  # distribution: Replicated
+         )
+         out_specs = (self.attn_o_tnh, P(None, None, "model"))
+
+         def _ragged_paged_attention(*args):
+             return ragged_paged_attention(
+                 *args,
+                 sm_scale=self.scale,
+                 q_scale=q_scale,
+                 k_scale=k_scale,
+                 v_scale=v_scale,
+             )
+
+         output_TNH, kv_cache = jax.jit(
+             shard_map.shard_map(
+                 _ragged_paged_attention,
+                 mesh=mesh,
+                 in_specs=in_specs,
+                 out_specs=out_specs,
+                 check_rep=False,
+             ))(
+                 q_TNH,
+                 k_SKH,
+                 v_SKH,
+                 kv_cache,
+                 md.seq_lens,
+                 md.block_tables,
+                 md.query_start_loc,
+                 md.request_distribution,
+             )
+         return kv_cache, output_TNH
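
Two scalar details in the MLA module above are easy to misread: the softmax scale folds a YaRN mscale correction into qk_head_dim**-0.5, and the attn_op scope pads every head dimension up to the next multiple of 128 before calling the kernel. The standalone sketch below recomputes both; the DeepSeek-V3-like sizes and rope factor are assumptions for illustration, not values read from this package.

    # Recompute the mscale-adjusted softmax scale and the pad-to-128 widths.
    # All concrete numbers below are assumed, DeepSeek-V3-like placeholders.
    import math

    qk_nope_head_dim = 128
    qk_rope_head_dim = 64
    qk_head_dim = qk_nope_head_dim + qk_rope_head_dim  # 192
    v_head_dim = 128
    rope_factor = 40.0          # rope_scaling["factor"], assumed
    rope_mscale_all_dim = 1.0

    # Mirrors __post_init__: mscale-adjusted softmax scale.
    if rope_factor <= 1.0:
        yarn_mscale = 1.0
    else:
        yarn_mscale = 0.1 * rope_mscale_all_dim * math.log(rope_factor) + 1.0
    sm_scale = qk_head_dim**-0.5 * yarn_mscale**2

    # Mirrors the attn_op scope: pad head dims up to a multiple of 128.
    multiple_of_128 = ((qk_head_dim - 1) // 128 + 1) * 128  # 192 -> 256
    q_pad = multiple_of_128 - qk_head_dim                   # extra columns on q and k
    v_pad = multiple_of_128 - v_head_dim                    # extra columns on v
    print(sm_scale, multiple_of_128, q_pad, v_pad)

This is what the in-code TODOs refer to: the 192-wide queries and keys are padded to 256, the 128-wide values are padded to match, and the kernel output is sliced back to v_head_dim afterwards.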