tpu-inference 0.11.1rc1__py3-none-any.whl → 0.11.1rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (50)
  1. tpu_inference/kernels/collectives/__init__.py +0 -0
  2. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  3. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  4. tpu_inference/kernels/collectives/util.py +47 -0
  5. tpu_inference/layers/__init__.py +0 -0
  6. tpu_inference/layers/common/__init__.py +0 -0
  7. tpu_inference/layers/common/attention_metadata.py +34 -0
  8. tpu_inference/layers/jax/__init__.py +0 -0
  9. tpu_inference/layers/jax/attention/__init__.py +0 -0
  10. tpu_inference/layers/jax/attention/attention.py +254 -0
  11. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  12. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  13. tpu_inference/layers/jax/attention_interface.py +356 -0
  14. tpu_inference/layers/jax/base.py +151 -0
  15. tpu_inference/layers/jax/binary_search.py +295 -0
  16. tpu_inference/layers/jax/constants.py +88 -0
  17. tpu_inference/layers/jax/layers.py +301 -0
  18. tpu_inference/layers/jax/misc.py +16 -0
  19. tpu_inference/layers/jax/moe/__init__.py +0 -0
  20. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  21. tpu_inference/layers/jax/moe/moe.py +209 -0
  22. tpu_inference/layers/jax/rope.py +172 -0
  23. tpu_inference/layers/jax/rope_interface.py +214 -0
  24. tpu_inference/layers/jax/sample/__init__.py +0 -0
  25. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  26. tpu_inference/layers/jax/sample/sampling.py +95 -0
  27. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  28. tpu_inference/layers/jax/sharding.py +406 -0
  29. tpu_inference/layers/jax/transformer_block.py +76 -0
  30. tpu_inference/layers/vllm/__init__.py +0 -0
  31. tpu_inference/layers/vllm/attention.py +184 -0
  32. tpu_inference/layers/vllm/fused_moe.py +399 -0
  33. tpu_inference/layers/vllm/linear_common.py +186 -0
  34. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  35. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  36. tpu_inference/layers/vllm/quantization/common.py +105 -0
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  38. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  39. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  40. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  41. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  42. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  43. tpu_inference/layers/vllm/sharding.py +151 -0
  44. tpu_inference/models/common/__init__.py +0 -0
  45. tpu_inference/models/common/model_loader.py +433 -0
  46. {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/METADATA +6 -6
  47. {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/RECORD +50 -5
  48. {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/WHEEL +0 -0
  49. {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/licenses/LICENSE +0 -0
  50. {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/top_level.txt +0 -0

tpu_inference/layers/jax/moe/moe.py
@@ -0,0 +1,209 @@
+ from dataclasses import InitVar, dataclass
+
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from flax.typing import Sharding
+ from jaxtyping import Float
+
+ from tpu_inference.layers.jax.base import create_param
+ from tpu_inference.layers.jax.layers import FlaxUtils
+
+ modeling_flax_utils = FlaxUtils()
+
+
+ @dataclass(kw_only=True)
+ class Router(nnx.Module):
+     """Router module for Mixture-of-Experts (MoE) layers.
+
+     This module determines which experts each token should be routed to based on the input.
+
+     Attributes:
+     """
+     dtype: jnp.dtype
+     hidden_size: int
+     num_experts: int
+     num_experts_per_tok: int
+     router_act: str
+     rngs: InitVar[nnx.Rngs]
+     activation_ffw_td: Sharding
+     ed_sharding: Sharding
+     random_init: bool = False
+
+     def __call__(self, x_TD: Float):
+         """Routes tokens to experts.
+
+         Args:
+             x_TD: Input array of shape (sequence_length, d_model).
+
+         Returns:
+             A tuple containing:
+             - normalized_weights_TX: Normalized weights for selected experts, shape (sequence_length, num_experts_per_tok).
+             - selected_experts_TX: Indices of selected experts, shape (sequence_length, num_experts_per_tok).
+         """
+         x_TD = jnp.asarray(x_TD, self.dtype)
+         x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
+         router_act = modeling_flax_utils.ACT2FN[self.router_act]
+         router_logits_TE = jnp.einsum('TD,DE -> TE', x_TD,
+                                       self.kernel_DE.value)
+         weights_TX, selected_experts_TX = jax.lax.top_k(
+             router_logits_TE, self.num_experts_per_tok)
+         if self.router_act != "sigmoid":  # sigmoid does not accept axis argument.
+             normalized_weights_TX = router_act(weights_TX.astype(self.dtype),
+                                                axis=-1)
+         else:
+             normalized_weights_TX = router_act(weights_TX.astype(self.dtype))
+         return normalized_weights_TX, selected_experts_TX
+
+     def __post_init__(self, rngs: nnx.Rngs):
+         """Generates the router kernel (weights) for routing."""
+         shape = (self.hidden_size, self.num_experts)
+         self.kernel_DE = create_param(rngs,
+                                       shape=shape,
+                                       dtype=self.dtype,
+                                       sharding=self.ed_sharding,
+                                       random_init=self.random_init)
+
+
+ @dataclass(kw_only=True)
+ class MoE(nnx.Module):
+     """Mixture-of-Experts (MoE) Routed MLP Layer.
+
+     This module implements a MoE layer with a router and multiple expert MLPs.
+
+     Attributes:
+         router: The Router module.
+     """
+     dtype: jnp.dtype
+     num_local_experts: int
+     apply_expert_weight_before_computation: bool
+     hidden_size: int
+     intermediate_size_moe: int
+     hidden_act: str
+     rngs: InitVar[nnx.Rngs]
+     router: nnx.Module
+     activation_ffw_td: Sharding
+     activation_ffw_ted: Sharding
+     edf_sharding: Sharding
+     efd_sharding: Sharding
+     random_init: bool = False
+
+     def __call__(self, x_TD: Float):
+         """Performs the forward pass of the MoE layer.
+
+         Args:
+             x_TD: Input array of shape (sequence_length, d_model).
+
+         Returns:
+             Output array of shape (sequence_length, d_model) after passing through MoE.
+         """
+         x_TD = jnp.asarray(x_TD, self.dtype)
+         x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
+         weights_TX, indices_TX = self.router(x_TD)
+         one_hot_indices_TXE = jax.nn.one_hot(
+             indices_TX, num_classes=self.num_local_experts, dtype=self.dtype)
+         full_weights_TE = jnp.sum(one_hot_indices_TXE * weights_TX[..., None],
+                                   axis=1)
+
+         # Some models use the routing scores to weight the data instead of
+         # weighting the expert outputs.
+         if self.apply_expert_weight_before_computation:
+             with jax.named_scope("pre_computing_weight"):
+                 return self._moe_fwd_preapply_router_weights(
+                     x_TD, full_weights_TE)
+         else:
+             return self._moe_fwd(x_TD, full_weights_TE)
+
+     def __post_init__(self, rngs: nnx.Rngs):
+         """Generates the kernels (weights) for the router and experts (gating, up-projection, and down-projection layers)."""
+
+         D = self.hidden_size
+         F = self.intermediate_size_moe
+         shape_gating = (self.num_local_experts, D, F)
+         shape_up = (self.num_local_experts, D, F)
+         shape_down = (self.num_local_experts, F, D)
+
+         self.kernel_gating_EDF = create_param(rngs,
+                                               shape=shape_gating,
+                                               dtype=self.dtype,
+                                               sharding=self.edf_sharding,
+                                               random_init=self.random_init)
+         self.kernel_up_proj_EDF = create_param(rngs,
+                                                shape=shape_up,
+                                                dtype=self.dtype,
+                                                sharding=self.edf_sharding,
+                                                random_init=self.random_init)
+         self.kernel_down_proj_EFD = create_param(rngs,
+                                                  shape=shape_down,
+                                                  dtype=self.dtype,
+                                                  sharding=self.efd_sharding,
+                                                  random_init=self.random_init)
+
+     def _moe_fwd_preapply_router_weights(self, x_TD: jax.Array, weights_TE):
+         """Performs the forward pass of the MoE experts with router weights pre-applied to the inputs.
+
+         Args:
+             x_TD: Input array for the experts, shape (sequence_length, hidden_size).
+             weights_TE: Router weights, shape (sequence_length, num_experts).
+
+         Returns:
+             Output array of shape (sequence_length, d_model).
+         """
+         # Data needs to be replicated since it will be weighted by the router
+         # scores before being passed to each expert.
+         num_experts = weights_TE.shape[-1]
+         x_TED = jnp.repeat(x_TD[:, None, :], num_experts, 1)
+         weights_TED = weights_TE[..., None]
+         x_TED = jnp.asarray(x_TED, self.dtype)
+
+         with jax.named_scope("activation_expert_weighting"):
+             x_TED = x_TED * weights_TED
+
+         x_TED = nnx.with_sharding_constraint(x_TED, self.activation_ffw_ted)
+         with jax.named_scope("gating"):
+             gating_TEF = jnp.einsum('TED,EDF -> TEF', x_TED,
+                                     self.kernel_gating_EDF.value)
+             activated_gating_TEF = modeling_flax_utils.ACT2FN[self.hidden_act](
+                 gating_TEF)
+         with jax.named_scope("up_projection"):
+             up_proj_TEF = jnp.einsum('TED,EDF -> TEF', x_TED,
+                                      self.kernel_up_proj_EDF.value)
+
+         fuse_TEF = activated_gating_TEF * up_proj_TEF
+
+         with jax.named_scope("down_projection"):
+             down_proj_TED = jnp.einsum('TEF,EFD -> TED', fuse_TEF,
+                                        self.kernel_down_proj_EFD.value)
+         with jax.named_scope("sum"):
+             output_TD = down_proj_TED.sum(axis=1)
+         return output_TD.astype(self.dtype)
+
+     def _moe_fwd(self, x_TD: Float, weights):
+         """Performs the basic forward pass of the MoE experts without dropping or megablocks.
+
+         Args:
+             x_TD: Input array for the experts, shape (sequence_length, d_model).
+             weights: Weights for combining expert outputs, shape (sequence_length, num_experts).
+
+         Returns:
+             Output array of shape (sequence_length, d_model).
+         """
+         x_TD = jnp.asarray(x_TD, self.dtype)
+         x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
+         with jax.named_scope("gating"):
+             gating_TEF = jnp.einsum('TD,EDF -> TEF', x_TD,
+                                     self.kernel_gating_EDF.value)
+             activated_gating_TEF = modeling_flax_utils.ACT2FN[self.hidden_act](
+                 gating_TEF)
+         with jax.named_scope("up_projection"):
+             up_proj_TEF = jnp.einsum('TD,EDF -> TEF', x_TD,
+                                      self.kernel_up_proj_EDF.value)
+
+         fuse_TEF = activated_gating_TEF * up_proj_TEF
+
+         with jax.named_scope("down_projection"):
+             down_proj_TED = jnp.einsum('TEF,EFD -> TED', fuse_TEF,
+                                        self.kernel_down_proj_EFD.value)
+         with jax.named_scope("sum"):
+             output_TD = jnp.einsum('TED,TE -> TD', down_proj_TED, weights)
+         return output_TD.astype(self.dtype)
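
For orientation, the block below is a minimal, self-contained sketch (not part of the diff) of the routed-expert math that the _moe_fwd path above implements: top-k routing, normalization of the selected scores, dense per-expert gating/up/down einsums, and a weighted combine. The sizes, the 0.02 init scale, and the silu/softmax choices are illustrative assumptions rather than values taken from the package, and the sketch omits the sharding constraints and named scopes used in the real module.

import jax
import jax.numpy as jnp

T, D, F, E, X = 8, 16, 32, 4, 2  # tokens, hidden, ffw, experts, experts per token
keys = jax.random.split(jax.random.PRNGKey(0), 5)
x_TD = jax.random.normal(keys[0], (T, D))
w_router_DE = 0.02 * jax.random.normal(keys[1], (D, E))
w_gate_EDF = 0.02 * jax.random.normal(keys[2], (E, D, F))
w_up_EDF = 0.02 * jax.random.normal(keys[3], (E, D, F))
w_down_EFD = 0.02 * jax.random.normal(keys[4], (E, F, D))

# Router: top-k over the logits, then normalize the selected scores.
logits_TE = jnp.einsum('TD,DE->TE', x_TD, w_router_DE)
weights_TX, experts_TX = jax.lax.top_k(logits_TE, X)
weights_TX = jax.nn.softmax(weights_TX, axis=-1)

# Scatter the k scores back into a dense (T, E) weight matrix.
one_hot_TXE = jax.nn.one_hot(experts_TX, E)
full_weights_TE = jnp.sum(one_hot_TXE * weights_TX[..., None], axis=1)

# Dense expert MLP: every expert processes every token; unselected experts
# are zeroed out by the router weights in the final combine.
gating_TEF = jax.nn.silu(jnp.einsum('TD,EDF->TEF', x_TD, w_gate_EDF))
up_TEF = jnp.einsum('TD,EDF->TEF', x_TD, w_up_EDF)
down_TED = jnp.einsum('TEF,EFD->TED', gating_TEF * up_TEF, w_down_EFD)
out_TD = jnp.einsum('TED,TE->TD', down_TED, full_weights_TE)
print(out_TD.shape)  # (8, 16)
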

tpu_inference/layers/jax/rope.py
@@ -0,0 +1,172 @@
+ import math
+ from dataclasses import dataclass, field
+ from typing import Optional
+
+ import jax
+ from flax import nnx
+ from jax import numpy as jnp
+
+
+ @dataclass(kw_only=True)
+ class RotaryEmbedding(nnx.Module):
+     """
+     An implementation of the original rotary positional embedding.
+     """
+     rotary_dim: int
+     rope_theta: float
+     original_max_position_embeddings: int
+     dtype: jnp.dtype
+     sin_cos_cache: Optional[jax.Array] = field(init=False, default=None)
+
+     def initialize_cache(self):
+         """Computes and caches the sin/cos embeddings."""
+         if self.sin_cos_cache is None:
+             self.sin_cos_cache = self._compute_sin_cos()
+
+     def _compute_inv_freq(self):
+         fractions_H = jnp.arange(0, self.rotary_dim, 2,
+                                  dtype=jnp.float32) / self.rotary_dim
+         inv_freq_H = 1.0 / (self.rope_theta**fractions_H)
+         return inv_freq_H
+
+     def _compute_sin_cos(self):
+         inv_freq_H = self._compute_inv_freq()
+         t = jnp.arange(self.original_max_position_embeddings,
+                        dtype=jnp.float32)
+
+         freqs = jnp.einsum("...T,k->...Tk",
+                            t,
+                            inv_freq_H,
+                            precision=jax.lax.Precision.HIGHEST)
+         sin, cos = jnp.sin(freqs), jnp.cos(freqs)
+         cache = jnp.concatenate((cos, sin), axis=-1)
+         return cache
+
+     def apply_rope(self, positions: jax.Array, x_TNH: jax.Array):
+         assert x_TNH.ndim == 3
+         assert self.sin_cos_cache is not None, "RoPE cache not initialized."
+         cos_sin_TH = self.sin_cos_cache[positions]
+         # cos, sin: (T, H/2)
+         cos_TH, sin_TH = jnp.split(cos_sin_TH, 2, axis=-1)
+         assert sin_TH.ndim == 2 and cos_TH.ndim == 2
+         # cos, sin: (T, 1, H/2)
+         cos_T1H, sin_T1H = cos_TH[:, None, :], sin_TH[:, None, :]
+         # first_half, second_half: (T, N, H/2)
+         first_half_TNH, second_half_TNH = jnp.split(x_TNH, 2, axis=-1)
+         combined = jnp.concatenate([
+             first_half_TNH * cos_T1H - second_half_TNH * sin_T1H,
+             second_half_TNH * cos_T1H + first_half_TNH * sin_T1H
+         ],
+                                    axis=-1)
+         return combined.astype(self.dtype)
+
+
+ @dataclass(kw_only=True)
+ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
+     """
+     Rotary Embedding for deepseek, with scaling and YaRN method.
+     """
+     scaling_factor: float
+     beta_fast: int = 32
+     beta_slow: int = 1
+     mscale_value: float = 1
+     mscale_all_dim: float = 0
+
+     def initialize_cache(self):
+         """Computes and caches the sin/cos embeddings."""
+         # The second condition is for the Qwix case, where we need to call `initialize_cache` on
+         # the abstract model. Thus, when we go to call `initialize_cache` on the concrete model,
+         # this method will have been called already, but we need to recompute the cache so that
+         # it's concrete (otherwise, it'll still be a jax.ShapeDtypeStruct).
+         if self.sin_cos_cache is not None and not isinstance(
+                 self.sin_cos_cache, jax.ShapeDtypeStruct):
+             return
+         self.mscale = _yarn_get_mscale(
+             self.scaling_factor, self.mscale_value) / _yarn_get_mscale(
+                 self.scaling_factor, self.mscale_all_dim)
+         self.sin_cos_cache = self._compute_sin_cos()
+
+     def _compute_inv_freq(self):
+         fractions = jnp.arange(0, self.rotary_dim, 2,
+                                dtype=jnp.float32) / self.rotary_dim
+         inv_freq_extrapolation = 1.0 / (self.rope_theta**fractions)
+         inv_freq_interpolation = 1.0 / (self.scaling_factor *
+                                         self.rope_theta**fractions)
+         low, high = _yarn_find_correction_range(
+             self.beta_fast, self.beta_slow, self.rotary_dim, self.rope_theta,
+             self.original_max_position_embeddings)
+
+         # Get n-d rotational scaling corrected for extrapolation
+         inv_freq_mask = 1 - _yarn_linear_ramp_mask(
+             low, high, self.rotary_dim // 2).astype(jnp.float32)
+         inv_freq = inv_freq_interpolation * (
+             1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+         return inv_freq
+
+     def _compute_sin_cos(self):
+         inv_freq_H = self._compute_inv_freq()
+         t = jnp.arange(self.original_max_position_embeddings *
+                        self.scaling_factor,
+                        dtype=jnp.float32)
+         freqs = jnp.einsum("...T,k->...Tk", t, inv_freq_H)
+         sin, cos = jnp.sin(freqs) * self.mscale, jnp.cos(freqs) * self.mscale
+         cache = jnp.concatenate((cos, sin), axis=-1)
+         return cache
+
+     def apply_rope(self, positions: jax.Array, x_TNH: jax.Array):
+         assert x_TNH.ndim == 3
+         assert self.sin_cos_cache is not None, "RoPE cache not initialized."
+         cos_sin_TH = self.sin_cos_cache[positions]
+         # cos, sin: (T, H/2)
+         cos_TH, sin_TH = jnp.split(cos_sin_TH, 2, axis=-1)
+         assert sin_TH.ndim == 2 and cos_TH.ndim == 2
+         # cos, sin: (T, 1, H/2)
+         cos_T1H, sin_T1H = cos_TH[:, None, :], sin_TH[:, None, :]
+         # even, odd: (T, N, H/2)
+         even_TNH, odd_TNH = x_TNH[..., ::2], x_TNH[..., 1::2]
+         combined_TNH = jnp.stack([
+             even_TNH * cos_T1H - odd_TNH * sin_T1H,
+             odd_TNH * cos_T1H + even_TNH * sin_T1H
+         ],
+                                  axis=-1).reshape(x_TNH.shape)
+         return combined_TNH.astype(self.dtype)
+
+
+ # Calculates the temperature scaling factor for YaRN to adjust
+ # RoPE embedding magnitudes.
+ def _yarn_get_mscale(scale, mscale):
+     return jnp.where(scale <= 1, 1.0, 0.1 * mscale * jnp.log(scale) + 1.0)
+
+
+ # Inverses dim formula to find dim based on number of rotations.
+ def _yarn_find_correction_dim(num_rotations,
+                               dim,
+                               base=10000,
+                               max_position_embeddings=2048):
+     return (dim * math.log(max_position_embeddings /
+                            (num_rotations * 2 * math.pi))) / (2 *
+                                                               math.log(base))
+
+
+ # Finds dim range bounds based on rotations.
+ def _yarn_find_correction_range(low_rot,
+                                 high_rot,
+                                 dim,
+                                 base=10000,
+                                 max_position_embeddings=2048):
+     low = math.floor(
+         _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
+     high = math.ceil(
+         _yarn_find_correction_dim(high_rot, dim, base,
+                                   max_position_embeddings))
+     return max(low, 0), min(high, dim - 1)  # Clamp values just in case
+
+
+ # Creates a 1D mask that ramps linearly from 0 to 1 between min and max indices.
+ def _yarn_linear_ramp_mask(min, max, dim):
+     if min == max:
+         max += 0.001  # Prevent singularity
+
+     linear_func = (jnp.arange(dim, dtype=jnp.float32) - min) / (max - min)
+     ramp_func = jnp.clip(linear_func, 0, 1)
+     return ramp_func
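
As a quick reference (again not part of the diff), the sketch below reproduces in plain JAX the cache layout and split-halves rotation used by RotaryEmbedding._compute_sin_cos and apply_rope, with made-up shapes; the DeepSeek/YaRN subclass only changes how the inverse frequencies and the mscale factor are computed.

import jax.numpy as jnp

T, N, H = 4, 2, 8  # tokens, heads, rotary_dim (illustrative)
theta = 10000.0
positions = jnp.arange(T)

inv_freq = 1.0 / (theta**(jnp.arange(0, H, 2, dtype=jnp.float32) / H))  # (H/2,)
freqs = positions[:, None].astype(jnp.float32) * inv_freq[None, :]      # (T, H/2)
cos, sin = jnp.cos(freqs)[:, None, :], jnp.sin(freqs)[:, None, :]       # (T, 1, H/2)

x_TNH = jnp.ones((T, N, H))
first, second = jnp.split(x_TNH, 2, axis=-1)
rotated = jnp.concatenate(
    [first * cos - second * sin, second * cos + first * sin], axis=-1)
print(rotated.shape)  # (4, 2, 8)
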

tpu_inference/layers/jax/rope_interface.py
@@ -0,0 +1,214 @@
+ import math
+ from typing import Any, Dict
+
+ import jax
+ import jax.numpy as jnp
+
+
+ def apply_rope(
+     # (seq_len, num_heads, head_dim)
+     inputs: jax.Array,
+     # (3, seq_len) for M-RoPE, otherwise (seq_len,)
+     positions: jax.Array,
+     head_dim: int,
+     rope_theta: float = 10000,
+     rope_scaling: Dict[str, Any] = None,
+     rope_input_ordering: str = "split",
+ ) -> jax.Array:
+     """
+     Applies Rotary Positional Embedding using the sine and cosine strategy.
+
+     This implementation assumes the input tensor has a shape that might include
+     padding on the last dimension (head_dim).
+     RoPE is applied only to the first `head_dim` features, and the result is
+     padded back to the original dimension if necessary.
+     If rope_input_ordering is "split", then the input pairs for rotation are taken one from the
+     first and one from the second half of the head_dim. If it is "interleaved" then
+     adjacent values are used as inputs for rotation.
+     """
+
+     # M-RoPE support for Qwen2.5-VL
+     if positions.ndim == 2 and positions.shape[0] == 3:
+         mrope_section = rope_scaling.get("mrope_section",
+                                          None) if rope_scaling else None
+         # NOTE: We assume mrope_section is always available
+         # as Qwen2.5-VL is the only model using mrope
+         assert mrope_section is not None
+
+         split_indices = [mrope_section[0], mrope_section[0] + mrope_section[1]]
+
+         # Indices for the features to be rotated (first half of head_dim)
+         all_freq_indices = jnp.arange(head_dim // 2)
+
+         # Split the indices according to mrope_section. This is valid because split_indices are static.
+         freq_indices_split = jnp.split(all_freq_indices, split_indices)
+         # freq_indices_split is a list of 3 JAX arrays.
+
+         cos_list = []
+         sin_list = []
+
+         for i in range(3):  # For each of the 3 position dimensions
+             current_indices = freq_indices_split[i]
+
+             if current_indices.size == 0:
+                 # This section is empty, skip.
+                 continue
+
+             # inv_freq shape: (mrope_section[i],)
+             inv_freq = 1.0 / (rope_theta**(current_indices * 2.0 / head_dim))
+
+             # positions[i]: (seq_len,)
+             # freqs shape: (seq_len, mrope_section[i])
+             freqs = jnp.outer(positions[i], inv_freq)
+
+             cos_list.append(jnp.cos(freqs))
+             sin_list.append(jnp.sin(freqs))
+
+         # Concatenate along the feature dimension
+         # cos, sin shape: (seq_len, head_dim//2)
+         cos = jnp.concatenate(cos_list, axis=1)
+         sin = jnp.concatenate(sin_list, axis=1)
+
+         # Add num_heads dimension for broadcasting
+         cos = cos[:, jnp.newaxis, :]  # Shape: (seq_len, 1, head_dim//2)
+         sin = sin[:, jnp.newaxis, :]  # Shape: (seq_len, 1, head_dim//2)
+
+         # Apply rotation
+         inputs_real = inputs[..., :head_dim // 2]
+         inputs_imag = inputs[..., head_dim // 2:head_dim]
+
+         outputs_real = inputs_real * cos - inputs_imag * sin
+         outputs_imag = inputs_real * sin + inputs_imag * cos
+
+         out = jnp.concatenate([outputs_real, outputs_imag], axis=-1)
+
+     # Standard RoPE
+     else:
+         # Calculate inverse frequencies (timescale)
+         fraction = 2 * jnp.arange(0, head_dim // 2) / head_dim
+         timescale = 1.0 / (rope_theta**fraction)
+
+         # Apply scaling if provided
+         if rope_scaling:
+             timescale = apply_rope_scaling(timescale, rope_scaling)
+
+         # Prepare for rotation by calculating sin and cos values
+         # `sinusoid_inp` gets shape (batch * seq_len, head_dim/2)
+         sinusoid_inp = positions[..., jnp.newaxis] * timescale[jnp.newaxis, :]
+
+         # Broadcast over the 'heads' dimension, assuming shape (batch*seq, heads, head_dim)
+         sinusoid_inp = sinusoid_inp[:, jnp.newaxis, ...]
+         sin = jnp.sin(sinusoid_inp)
+         cos = jnp.cos(sinusoid_inp)
+
+         if rope_input_ordering == "interleaved":
+             # Reshape to group adjacent features for rotation, matching new_apply_rope
+             rotary_inputs = inputs[
+                 ..., :head_dim]  # Take just the non-padded amount.
+             reshaped_inputs = rotary_inputs.reshape(*rotary_inputs.shape[:-1],
+                                                     -1, 2)
+
+             # Apply the rotation
+             first_half = reshaped_inputs[..., 0]
+             second_half = reshaped_inputs[..., 1]
+         else:
+             first_half = inputs[..., :head_dim // 2]
+             second_half = inputs[..., head_dim // 2:head_dim]
+
+         first_part = first_half * cos - second_half * sin
+         second_part = second_half * cos + first_half * sin
+
+         # Combine the rotated parts and reshape back
+         if rope_input_ordering == "interleaved":
+             out_stacked = jnp.stack([first_part, second_part], axis=-1)
+             out = out_stacked.reshape(rotary_inputs.shape)
+         else:
+             out = jnp.concatenate([first_part, second_part], axis=-1)
+
+     # If the original input was padded, pad the output with zeros to match.
+     padded_head_dim = inputs.shape[-1]
+     if padded_head_dim > head_dim:
+         pad_width = padded_head_dim - head_dim
+         pad_config = [(0, 0)] * (out.ndim - 1) + [(0, pad_width)]
+         out = jnp.pad(out, pad_config)
+
+     return out.astype(inputs.dtype)
+
+
+ def apply_longrope(
+     inputs: jax.Array,
+     positions: jax.Array,
+     head_dim: int,
+     rope_scaling: Dict[str, Any],
+     original_max_position_embeddings: int,
+     max_position_embeddings: int,
+     rope_theta: float = 10000,
+ ) -> jax.Array:
+     # LongRoPE implementation specific to Phi-3
+     # Implementation based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/phi3/modeling_phi3.py#L197-L235
+
+     scale = max_position_embeddings / original_max_position_embeddings
+     if scale <= 1.0:
+         mscale = 1.0
+     else:
+         mscale = jnp.sqrt(1 + (jnp.log(scale) /
+                                jnp.log(original_max_position_embeddings)))
+
+     seq_len = inputs.shape[0]
+     if seq_len > original_max_position_embeddings:
+         long_factor = jnp.array(rope_scaling.get("long_factor"))
+         timescale = 1.0 / (long_factor * (rope_theta**(
+             (2 * jnp.arange(0, head_dim // 2)) / head_dim)))
+     else:
+         short_factor = jnp.array(rope_scaling.get("short_factor"))
+         timescale = 1.0 / (short_factor * (rope_theta**(
+             (2 * jnp.arange(0, head_dim // 2)) / head_dim)))
+
+     # Calculate RoPE positions
+     sinusoid_inp = positions[..., jnp.newaxis] * timescale[jnp.newaxis, :]
+     sinusoid_inp = sinusoid_inp[:, jnp.newaxis, ...]
+     sin = jnp.sin(sinusoid_inp) * mscale
+     cos = jnp.cos(sinusoid_inp) * mscale
+
+     # Padding logic
+     padded_head_dim = inputs.shape[-1]
+
+     # Apply RoPE mechanism
+     first_half = inputs[..., :head_dim // 2]
+     second_half = inputs[..., head_dim // 2:head_dim]
+
+     first_part = first_half * cos - second_half * sin
+     second_part = second_half * cos + first_half * sin
+     out = jnp.concatenate([first_part, second_part], axis=-1)
+
+     if padded_head_dim > head_dim:
+         out = jnp.pad(out, ((0, 0), (0, 0), (0, padded_head_dim - head_dim)))
+
+     return out.astype(inputs.dtype)
+
+
+ def apply_rope_scaling(freqs: jax.Array, rope_scaling: Dict[str,
+                                                             Any]) -> jax.Array:
+     # Values obtained from grid search
+     scale_factor = rope_scaling.get("scale_factor", 8.0)
+     low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
+     high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
+     old_context_len = rope_scaling.get("original_max_position_embeddings",
+                                        8192)
+
+     low_freq_wavelen = old_context_len / low_freq_factor
+     high_freq_wavelen = old_context_len / high_freq_factor
+
+     wavelen = 2 * math.pi / freqs
+     smooth = (old_context_len / wavelen -
+               low_freq_factor) / (high_freq_factor - low_freq_factor)
+
+     high_freqs = jnp.where(wavelen < high_freq_wavelen, freqs, 0)
+     low_freqs = jnp.where(wavelen > low_freq_wavelen, freqs / scale_factor, 0)
+     mid_freqs = jnp.where(
+         (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen),
+         (1 - smooth) * freqs / scale_factor + smooth * freqs,
+         0,
+     )
+     new_freqs = high_freqs + low_freqs + mid_freqs
+     return new_freqs
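
A small usage example of the new helper (not part of the diff; it assumes the tpu-inference wheel above is installed, and the shapes plus the Llama-style rope_scaling dictionary are illustrative values that simply mirror the defaults read by apply_rope_scaling):

import jax
import jax.numpy as jnp
from tpu_inference.layers.jax.rope_interface import apply_rope

seq_len, num_heads, head_dim = 16, 4, 64
q = jax.random.normal(jax.random.PRNGKey(0), (seq_len, num_heads, head_dim))
positions = jnp.arange(seq_len)

# Plain RoPE over the split halves of head_dim.
q_rope = apply_rope(q, positions, head_dim, rope_theta=10000)

# The same call with Llama-3-style frequency scaling handled by apply_rope_scaling.
rope_scaling = {
    "scale_factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}
q_scaled = apply_rope(q, positions, head_dim, rope_scaling=rope_scaling)
print(q_rope.shape, q_scaled.shape)  # (16, 4, 64) (16, 4, 64)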