tpu-inference 0.11.1rc1__py3-none-any.whl → 0.11.1rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic; see the package registry page for more details.

Files changed (50):
  1. tpu_inference/kernels/collectives/__init__.py +0 -0
  2. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  3. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  4. tpu_inference/kernels/collectives/util.py +47 -0
  5. tpu_inference/layers/__init__.py +0 -0
  6. tpu_inference/layers/common/__init__.py +0 -0
  7. tpu_inference/layers/common/attention_metadata.py +34 -0
  8. tpu_inference/layers/jax/__init__.py +0 -0
  9. tpu_inference/layers/jax/attention/__init__.py +0 -0
  10. tpu_inference/layers/jax/attention/attention.py +254 -0
  11. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  12. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  13. tpu_inference/layers/jax/attention_interface.py +356 -0
  14. tpu_inference/layers/jax/base.py +151 -0
  15. tpu_inference/layers/jax/binary_search.py +295 -0
  16. tpu_inference/layers/jax/constants.py +88 -0
  17. tpu_inference/layers/jax/layers.py +301 -0
  18. tpu_inference/layers/jax/misc.py +16 -0
  19. tpu_inference/layers/jax/moe/__init__.py +0 -0
  20. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  21. tpu_inference/layers/jax/moe/moe.py +209 -0
  22. tpu_inference/layers/jax/rope.py +172 -0
  23. tpu_inference/layers/jax/rope_interface.py +214 -0
  24. tpu_inference/layers/jax/sample/__init__.py +0 -0
  25. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  26. tpu_inference/layers/jax/sample/sampling.py +95 -0
  27. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  28. tpu_inference/layers/jax/sharding.py +406 -0
  29. tpu_inference/layers/jax/transformer_block.py +76 -0
  30. tpu_inference/layers/vllm/__init__.py +0 -0
  31. tpu_inference/layers/vllm/attention.py +184 -0
  32. tpu_inference/layers/vllm/fused_moe.py +399 -0
  33. tpu_inference/layers/vllm/linear_common.py +186 -0
  34. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  35. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  36. tpu_inference/layers/vllm/quantization/common.py +105 -0
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  38. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  39. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  40. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  41. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  42. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  43. tpu_inference/layers/vllm/sharding.py +151 -0
  44. tpu_inference/models/common/__init__.py +0 -0
  45. tpu_inference/models/common/model_loader.py +433 -0
  46. {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/METADATA +6 -6
  47. {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/RECORD +50 -5
  48. {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/WHEEL +0 -0
  49. {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/licenses/LICENSE +0 -0
  50. {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,354 @@
1
+ import math
2
+ from dataclasses import InitVar, dataclass
3
+ from typing import Any, Tuple
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ from flax import nnx
8
+ from flax.typing import Sharding
9
+ from jax.experimental import shard_map
10
+ from jax.sharding import Mesh
11
+ from jax.sharding import PartitionSpec as P
12
+
13
+ from tpu_inference import utils
14
+ from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
15
+ ragged_paged_attention
16
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
17
+ from tpu_inference.layers.jax.base import create_param
18
+ from tpu_inference.layers.jax.layers import RMSNorm
19
+ from tpu_inference.layers.jax.rope import DeepseekScalingRotaryEmbedding
20
+
21
# Type alias for the KV cache: a pair of arrays (presumably the key and value
# caches -- NOTE(review): confirm the layout against the attention kernel).
KVCache = Tuple[jax.Array, jax.Array]
22
+
23
+
24
# TODO (wenxindongwork): Add MLA KV cache implementation. For now, cache complete KV vectors.
@dataclass(kw_only=True)
class MLA(nnx.Module):
    """Multi-Head Latent Attention (MLA) as described in the DeepSeek V3 paper.

    Queries and keys/values are produced through low-rank down/up projections
    with an RMSNorm on each latent. Every query/key head is the concatenation
    of a non-rotary ("nope") part and a rotary ("rope") part; the rotary key
    part is computed once and broadcast to all heads. For now the full (padded)
    K/V vectors are cached rather than the compressed latent (see TODO above).

    Attributes:
        hidden_size: Model dimension D.
        num_attention_heads: Number of query heads N.
        num_key_value_heads: Number of KV heads; must equal N for MLA.
        head_dim: Stored but not read by this class.
        rope_theta: RoPE base frequency.
        rope_scaling: RoPE scaling config; must provide "factor",
            "original_max_position_embeddings", "beta_fast", "beta_slow",
            "mscale" and "mscale_all_dim".
        dtype: Parameter / compute dtype.
        kv_cache_dtype: "auto", or a dtype string for a quantized KV cache.
        mesh: The JAX device mesh for distributed computation.
        q_lora_rank: Rank of the query latent projection.
        kv_lora_rank: Rank of the KV latent projection.
        qk_nope_head_dim: Non-rotary part of each q/k head.
        qk_rope_head_dim: Rotary part of each q/k head.
        v_head_dim: Per-head value dimension.
        rms_norm_eps: Epsilon of the latent RMSNorms.
        *_sharding, activation_*, query_tnh, keyvalue_skh, attn_o_tnh:
            Sharding specs for weights and activations.
        random_init: Forwarded to `create_param` and `RMSNorm`.
        attention_chunk_size, rope_input_ordering, quant: Not read here.
        rope_mscale_all_dim: mscale_all_dim used in the softmax-scale
            correction.
    """
    hidden_size: int
    num_attention_heads: int
    num_key_value_heads: int
    head_dim: int
    rope_theta: float
    rope_scaling: dict[str, Any]
    dtype: jnp.dtype
    kv_cache_dtype: str
    mesh: Mesh

    q_lora_rank: int
    kv_lora_rank: int
    qk_nope_head_dim: int
    qk_rope_head_dim: int
    v_head_dim: int
    rms_norm_eps: float

    # Sharding attributes
    nhd_sharding: Sharding = ()
    q_da_sharding: Sharding = ()
    anh_sharding: Sharding = ()
    kv_da_sharding: Sharding = ()

    activation_attention_td: Sharding = ()
    activation_q_td: Sharding = ()
    query_tnh: P = P()
    keyvalue_skh: P = P()

    attn_o_tnh: P = P()
    activation_attention_out_td: Sharding = ()

    random_init: bool = False
    attention_chunk_size: int | None = None
    rope_input_ordering: str = "split"
    quant: Any | None = None
    rope_mscale_all_dim: float = 1.0

    rngs: InitVar[nnx.Rngs]

    # Static quantization scales used when the KV cache is quantized.
    _q_scale: float = 1
    _k_scale: float = 1
    _v_scale: float = 1

    def __post_init__(self, rngs: nnx.Rngs):
        self.N = self.num_attention_heads
        self.K = self.num_key_value_heads
        self.D = self.hidden_size

        # Each q/k head is laid out as [nope part | rope part].
        self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim

        assert self.N == self.K, "N and K must be equal for MLA"

        # Softmax scale: 1/sqrt(qk_head_dim), multiplied by mscale**2 when the
        # RoPE scaling factor exceeds 1, where
        # mscale = 0.1 * rope_mscale_all_dim * ln(factor) + 1.
        if self.rope_scaling["factor"] <= 1.0:
            yarn_mscale = 1.0
        else:
            yarn_mscale = 0.1 * self.rope_mscale_all_dim * math.log(
                self.rope_scaling["factor"]) + 1.0
        self.scale = self.qk_head_dim**-0.5 * yarn_mscale**2

        self.rope = DeepseekScalingRotaryEmbedding(
            rotary_dim=self.qk_rope_head_dim,
            rope_theta=self.rope_theta,
            original_max_position_embeddings=self.
            rope_scaling["original_max_position_embeddings"],
            scaling_factor=self.rope_scaling["factor"],
            dtype=self.dtype,
            beta_fast=self.rope_scaling["beta_fast"],
            beta_slow=self.rope_scaling["beta_slow"],
            mscale_value=self.rope_scaling["mscale"],
            mscale_all_dim=self.rope_scaling["mscale_all_dim"],
        )

        # Weight kernels. Suffix letters name the axes:
        # D=hidden, A=latent rank, N=heads, H=head dim.
        self.kernel_q_down_proj_DA = create_param(rngs,
                                                  (self.D, self.q_lora_rank),
                                                  self.q_da_sharding,
                                                  self.dtype,
                                                  random_init=self.random_init)
        self.kernel_q_up_proj_ANH = create_param(
            rngs,
            (self.q_lora_rank, self.N, self.qk_head_dim),
            self.anh_sharding,
            self.dtype,
            random_init=self.random_init,
        )
        # KV down projection also emits the shared rotary key
        # (last qk_rope_head_dim columns).
        self.kernel_kv_down_proj_DA = create_param(
            rngs,
            (self.D, self.kv_lora_rank + self.qk_rope_head_dim),
            self.kv_da_sharding,
            self.dtype,
            random_init=self.random_init,
        )
        # KV up projection emits [k_nope | v] per head.
        self.kernel_kv_up_proj_ANH = create_param(
            rngs,
            (self.kv_lora_rank, self.N,
             self.qk_nope_head_dim + self.v_head_dim),
            self.anh_sharding,
            self.dtype,
            random_init=self.random_init,
        )
        self.kernel_o_proj_NHD = create_param(
            rngs, (self.N, self.v_head_dim, self.D),
            self.nhd_sharding,
            self.dtype,
            random_init=self.random_init)
        self.q_rms_norm = RMSNorm(
            dims=self.q_lora_rank,
            epsilon=self.rms_norm_eps,
            with_scale=True,
            dtype=self.dtype,
            random_init=self.random_init,
            rngs=rngs,
        )

        self.kv_rms_norm = RMSNorm(
            dims=self.kv_lora_rank,
            random_init=self.random_init,
            epsilon=self.rms_norm_eps,
            with_scale=True,
            dtype=self.dtype,
            rngs=rngs,
        )

        # None means the KV cache is stored unquantized.
        self.kv_cache_quantized_dtype = None
        if self.kv_cache_dtype != "auto":
            self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
                self.kv_cache_dtype)

    def __call__(self,
                 x,
                 is_prefill,
                 kv_cache: KVCache,
                 attention_metadata: AttentionMetadata,
                 use_attention_rope: bool = True):
        """Runs the MLA forward pass over a flattened batch of tokens.

        Args:
            x: Input activations; the einsums below treat it as a 2-D
                `(num_tokens, hidden_size)` array.
            is_prefill: Forwarded to `attention`, which does not read it.
            kv_cache: The key-value cache for storing past attention states.
            attention_metadata: Metadata for attention (input positions,
                sequence lengths, block tables, request distribution).
            use_attention_rope: Not read. RoPE is always applied to the
                rotary sub-dimension of q and k. The parameter keeps the
                signature aligned with the other attention modules.

        Returns:
            A tuple `(new_kv_cache, output)`, where `output` has shape
            `(num_tokens, hidden_size)`.
        """
        md = attention_metadata
        x = jnp.asarray(x, self.dtype)
        x_SD = nnx.with_sharding_constraint(x, self.activation_attention_td)
        x_q_TD = nnx.with_sharding_constraint(x, self.activation_q_td)

        with jax.named_scope("q_proj"):
            # Query down projection into the latent space, then RMSNorm.
            q_TA = jnp.einsum("TD,DA -> TA", x_q_TD,
                              self.kernel_q_down_proj_DA.value)
            q_TA = self.q_rms_norm(q_TA)
            # Query up projection to N heads of size qk_head_dim.
            q_TNH = jnp.einsum("TA,ANH -> TNH", q_TA,
                               self.kernel_q_up_proj_ANH.value)
            # Apply RoPE only to the rotary part of each head.
            q_nope_TNH = q_TNH[..., :self.qk_nope_head_dim]
            q_rope_TNH = q_TNH[..., self.qk_nope_head_dim:]
            q_rope_TNH = self.rope.apply_rope(md.input_positions, q_rope_TNH)
            q_TNH = jnp.concatenate([q_nope_TNH, q_rope_TNH], axis=-1)
            # The softmax scale (self.scale) is applied inside the attention
            # kernel via `sm_scale`, not here.
            q_TNH = nnx.with_sharding_constraint(q_TNH, self.query_tnh)

        with jax.named_scope("kv_proj"):
            # KV down projection. The first kv_lora_rank columns are the
            # latent KV vector; the remaining qk_rope_head_dim columns are
            # the rotary key shared by all heads.
            kv_SA = jnp.einsum("SD,DA -> SA", x_SD,
                               self.kernel_kv_down_proj_DA.value)
            k_rope_SH = kv_SA[..., self.kv_lora_rank:]
            # Add a singleton head axis so RoPE can be applied, then
            # broadcast the shared rotary key to every head.
            k_rope_SNH = k_rope_SH[..., None, :]
            k_rope_SNH = self.rope.apply_rope(md.input_positions, k_rope_SNH)
            k_rope_SNH = jnp.broadcast_to(
                k_rope_SNH,
                (k_rope_SNH.shape[0], self.N, self.qk_rope_head_dim))
            kv_SA = kv_SA[..., :self.kv_lora_rank]
            kv_SA = self.kv_rms_norm(kv_SA)
            # KV up projection to N heads of [k_nope | v].
            kv_nope_SNH = jnp.einsum("SA,ANH -> SNH", kv_SA,
                                     self.kernel_kv_up_proj_ANH.value)
            k_nope_SNH = kv_nope_SNH[..., :self.qk_nope_head_dim]
            v_SNH = kv_nope_SNH[..., self.qk_nope_head_dim:]
            # Full key = [k_nope | k_rope], matching the query layout.
            k_SNH = jnp.concatenate([k_nope_SNH, k_rope_SNH], axis=-1)
            k_SNH = nnx.with_sharding_constraint(k_SNH, self.keyvalue_skh)
            v_SNH = nnx.with_sharding_constraint(v_SNH, self.keyvalue_skh)

        with jax.named_scope("attn_op"):
            # TODO(wenxindongwork): K and V have different head dimensions,
            # which the current KV cache implementation does not support, and
            # the attention kernel expects head dimensions that are multiples
            # of 128 (DeepSeek V3's qk head dim is 192). For now q, k and v are
            # zero-padded to one common head dim. We should update the MLA KV
            # cache implementation in the future.
            # Pad to the larger of the two head dims, rounded up to 128, so
            # neither pad width can go negative. Zero padding of q/k does not
            # change the dot products.
            padded_head_dim = max(self.qk_head_dim, self.v_head_dim)
            padded_head_dim = ((padded_head_dim - 1) // 128 + 1) * 128
            q_TNH = jnp.pad(q_TNH, ((0, 0), (0, 0),
                                    (0, padded_head_dim - self.qk_head_dim)))
            k_SNH = jnp.pad(k_SNH, ((0, 0), (0, 0),
                                    (0, padded_head_dim - self.qk_head_dim)))
            v_SNH = jnp.pad(v_SNH, ((0, 0), (0, 0),
                                    (0, padded_head_dim - self.v_head_dim)))
            q_scale = k_scale = v_scale = None
            if self.kv_cache_quantized_dtype:
                # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
                # q_scale = self._q_scale
                k_scale = self._k_scale
                v_scale = self._v_scale
                k_SNH, v_SNH = utils.quantize_kv(k_SNH, v_SNH,
                                                 self.kv_cache_quantized_dtype,
                                                 k_scale, v_scale)
            new_kv_cache, outputs_TNH = self.attention(
                is_prefill,
                kv_cache,
                q_TNH,
                k_SNH,
                v_SNH,
                attention_metadata,
                self.mesh,
                q_scale,
                k_scale,
                v_scale,
            )
            # TODO(wenxindongwork): For now, unpad the outputs to v_head_dim.
            # We shall add the MLA kv cache implementation in the future.
            outputs_TNH = outputs_TNH[..., :self.v_head_dim]

        with jax.named_scope("o_proj"):
            o_TD = jnp.einsum("TNH,NHD -> TD", outputs_TNH,
                              self.kernel_o_proj_NHD.value)
            o_TD = nnx.with_sharding_constraint(
                o_TD, self.activation_attention_out_td)
        return new_kv_cache, o_TD

    def attention(
        self,
        is_prefill: bool,
        kv_cache: KVCache,
        q_TNH: jax.Array,
        k_SKH: jax.Array,
        v_SKH: jax.Array,
        attention_metadata: AttentionMetadata,
        mesh: Mesh,
        q_scale: float | None = None,
        k_scale: float | None = None,
        v_scale: float | None = None,
    ) -> Tuple[KVCache, jax.Array]:
        """Runs the ragged paged attention kernel and updates the KV cache.

        The kernel is wrapped in `shard_map` over `mesh`. The KV cache and the
        output are sharded on the "model" axis. The metadata arrays are
        replicated.

        Args:
            is_prefill: Not read. Prefill and decode are both handled by the
                ragged kernel through `attention_metadata`.
            kv_cache: The key-value cache to be updated and used.
            q_TNH: Query tensor of shape `(query_seq, num_attention_heads, head_dim)`.
            k_SKH: Key tensor of shape `(kv_seq, num_key_value_heads, head_dim)`.
            v_SKH: Value tensor of shape `(kv_seq, num_key_value_heads, head_dim)`.
            attention_metadata: Provides `seq_lens`, `block_tables`,
                `query_start_loc` and `request_distribution`.
            mesh: The JAX device mesh the kernel is `shard_map`-ed over.
            q_scale: Quantization scale for q.
            k_scale: Quantization scale for k.
            v_scale: Quantization scale for v.

        Returns:
            A tuple containing:
              - The updated KV cache.
              - The attention output tensor of shape
                `(seq, num_q_heads, head_dim)`.
        """
        md = attention_metadata
        in_specs = (
            self.query_tnh,  # q
            self.keyvalue_skh,  # k
            self.keyvalue_skh,  # v
            P(None, None, "model"),  # kv_cache
            P(),  # md.seq_lens: Replicated
            P(),  # md.block_tables: Replicated
            P(),  # md.query_start_loc: Replicated
            P(),  # md.request_distribution: Replicated
        )
        out_specs = (self.attn_o_tnh, P(None, None, "model"))

        def _ragged_paged_attention(*args):
            return ragged_paged_attention(
                *args,
                sm_scale=self.scale,
                q_scale=q_scale,
                k_scale=k_scale,
                v_scale=v_scale,
            )

        output_TNH, kv_cache = jax.jit(
            shard_map.shard_map(
                _ragged_paged_attention,
                mesh=mesh,
                in_specs=in_specs,
                out_specs=out_specs,
                check_rep=False,
            ))(
                q_TNH,
                k_SKH,
                v_SKH,
                kv_cache,
                md.seq_lens,
                md.block_tables,
                md.query_start_loc,
                md.request_distribution,
            )
        return kv_cache, output_TNH
@@ -0,0 +1,153 @@
1
+ from dataclasses import dataclass
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ from flax import nnx
6
+ from jax.sharding import Sharding
7
+
8
+ from tpu_inference import utils
9
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
10
+ from tpu_inference.layers.jax.attention.attention import Attention, KVCache
11
+ from tpu_inference.layers.jax.rope_interface import apply_rope
12
+ from tpu_inference.logger import init_logger
13
+
14
# Module-level logger, named after this module.
logger = init_logger(__name__)
15
+
16
+
17
class L2Norm(nnx.Module):
    """Parameter-free RMS-style normalization over the last axis.

    Ported from the MaxText repo (maxtext/MaxText/layers/attentions.py). Each
    vector along the final axis is divided by the square root of its mean
    square, plus `eps`. No learnable scale is applied.

    Attributes:
        eps: Small constant added to the mean square for numerical
            stability; the default should be fine for most cases.
    """

    def __init__(self, eps: float = 1e-6):
        self.eps = eps

    def __call__(self, x):
        # Mean of squares along the feature axis, kept for broadcasting.
        mean_square = jnp.mean(x**2, axis=-1, keepdims=True)
        inv_rms = jax.lax.rsqrt(mean_square + self.eps)
        return x * inv_rms
31
+
32
+
33
@dataclass(kw_only=True)
class Llama4Attention(Attention):
    """Llama 4 attention layer built on the shared `Attention` base.

    Adds two options on top of the base projections and attention op:
      * L2 normalization (parameter-free, RMS-style) of q and k after RoPE,
        on layers where RoPE is applied (`use_qk_norm`).
      * Position-dependent query temperature tuning on layers called without
        RoPE (NoPE layers), which counteracts dilution of attention scores in
        long-context attention.

    Attributes:
        use_qk_norm: Apply `L2Norm` to q and k after RoPE.
        temperature_tuning: On NoPE calls, scale queries by
            log(floor((pos + 1) / floor_scale) + 1) * scale + 1.
        temperature_tuning_floor_scale: `floor_scale` in the formula above.
        temperature_tuning_scale: `scale` in the formula above.
        activation_attention_td: Sharding of the input used for the k/v
            projections.
        activation_attention_out_td: Sharding of the output activations.
    """
    use_qk_norm: bool
    temperature_tuning: bool
    temperature_tuning_floor_scale: float
    temperature_tuning_scale: float
    activation_attention_td: Sharding
    activation_attention_out_td: Sharding

    def __call__(self,
                 x,
                 is_prefill,
                 kv_cache: KVCache,
                 attention_metadata: AttentionMetadata,
                 use_attention_rope: bool = True):
        """Performs the forward pass of the attention module.

        Projects `x` to queries, keys and values. When `use_attention_rope`
        is set, it applies RoPE and, optionally, L2Norm to q and k. It then
        runs attention and projects the result back to the model dimension.
        On layers without RoPE (NoPE), the queries can instead be
        temperature-tuned, which helps against dilution of attention scores
        in long-context attention.

        Args:
            x: The input tensor of shape `(seq_len, d_model)`.
            is_prefill: Whether the operation mode is prefill (otherwise it is generate).
            kv_cache: The key-value cache for storing past attention states.
            attention_metadata: Metadata for attention, such as input positions.
            use_attention_rope: Whether to use RoPE.

        Returns:
            A tuple containing:
              - The updated KV cache.
              - The attention output tensor of shape `(seq_len, d_model)`.
        """
        md = attention_metadata
        x = jnp.asarray(x, self.dtype)
        x_SD = nnx.with_sharding_constraint(x, self.activation_attention_td)
        x_q_TD = nnx.with_sharding_constraint(x, self.activation_q_td)
        rope_scaling = self.rope_scaling
        rope_theta = self.rope_theta
        H = self.head_dim
        l2_norm = L2Norm()

        with jax.named_scope("q_proj"):
            q_TNH = jnp.einsum('TD,DNH -> TNH', x_q_TD,
                               self.kernel_q_proj_DNH.value)
            if use_attention_rope:
                q_TNH = apply_rope(q_TNH, md.input_positions, H, rope_theta,
                                   rope_scaling, self.rope_input_ordering)

                # Apply normalization after RoPE.
                if self.use_qk_norm:
                    q_TNH = l2_norm(q_TNH)
            else:
                # NoPE layer: optionally rescale queries by position.
                if self.temperature_tuning:
                    q_TNH = self.apply_temperature_tuning(md, q_TNH)

            q_TNH = nnx.with_sharding_constraint(q_TNH, self.query_tnh)
        with jax.named_scope("k_proj"):
            k_SKH = jnp.einsum('SD,DKH -> SKH', x_SD,
                               self.kernel_k_proj_DKH.value)
            if use_attention_rope:
                k_SKH = apply_rope(k_SKH, md.input_positions, H, rope_theta,
                                   rope_scaling, self.rope_input_ordering)

                # Apply normalization after RoPE.
                if self.use_qk_norm:
                    k_SKH = l2_norm(k_SKH)
            k_SKH = nnx.with_sharding_constraint(k_SKH, self.keyvalue_skh)

        with jax.named_scope("v_proj"):
            v_SKH = jnp.einsum('SD,DKH -> SKH', x_SD,
                               self.kernel_v_proj_DKH.value)
            v_SKH = nnx.with_sharding_constraint(v_SKH, self.keyvalue_skh)

        q_scale = k_scale = v_scale = None
        if self.kv_cache_quantized_dtype:
            # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
            # q_scale = self._q_scale
            k_scale = self._k_scale
            v_scale = self._v_scale
            k_SKH, v_SKH = utils.quantize_kv(k_SKH, v_SKH,
                                             self.kv_cache_quantized_dtype,
                                             k_scale, v_scale)

        with jax.named_scope("attn_op"):
            new_kv_cache, outputs_TNH = self.attention(
                is_prefill,
                kv_cache,
                q_TNH,
                k_SKH,
                v_SKH,
                attention_metadata,
                self.mesh,
                q_scale=q_scale,
                k_scale=k_scale,
                v_scale=v_scale,
            )

        with jax.named_scope("o_proj"):
            o_TD = jnp.einsum('TNH,NHD -> TD', outputs_TNH,
                              self.kernel_o_proj_NHD.value)
            o_TD = nnx.with_sharding_constraint(
                o_TD, self.activation_attention_out_td)
        return new_kv_cache, o_TD

    def apply_temperature_tuning(self, md: AttentionMetadata,
                                 input_arr_TNH: jax.Array) -> jax.Array:
        """Applies temperature tuning to the input array of shape (T, N, H).

        Each token is scaled by

            log(floor((pos + 1) / temperature_tuning_floor_scale) + 1)
                * temperature_tuning_scale + 1

        The formula is evaluated in float32. Casting the integer positions to
        a low-precision dtype such as bfloat16 (8 significant bits) first
        would round large positions, pushing tokens near a bucket boundary
        into the wrong floor bucket. The resulting scales are cast to
        `self.dtype`, so the output dtype matches the previous behavior.

        Args:
            md: AttentionMetadata object containing the input positions,
                one per token (shape (T,)).
            input_arr_TNH: Input array of shape (T, N, H) which will have
                scaled temperatures applied.

        Returns:
            `input_arr_TNH` multiplied by its per-token scale.
        """
        positions_T = md.input_positions.astype(jnp.float32)
        buckets_T = jnp.floor(
            (positions_T + 1.0) / self.temperature_tuning_floor_scale)
        attn_scales_T = (jnp.log(buckets_T + 1.0) *
                         self.temperature_tuning_scale + 1.0)
        attn_scales_T = attn_scales_T.astype(self.dtype)
        return input_arr_TNH * attn_scales_T[:, None, None]