tpu-inference 0.11.1rc2__py3-none-any.whl → 0.11.1rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +254 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/attention_interface.py +356 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/binary_search.py +295 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +172 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +95 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
- tpu_inference/layers/jax/sharding.py +406 -0
- tpu_inference/layers/jax/transformer_block.py +76 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +184 -0
- tpu_inference/layers/vllm/fused_moe.py +399 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +34 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
- tpu_inference/layers/vllm/sharding.py +151 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +433 -0
- {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/METADATA +6 -6
- {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/RECORD +50 -5
- {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/WHEEL +1 -1
- {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/top_level.txt +0 -0
tpu_inference/layers/jax/moe/moe.py
@@ -0,0 +1,209 @@
+from dataclasses import InitVar, dataclass
+
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from flax.typing import Sharding
+from jaxtyping import Float
+
+from tpu_inference.layers.jax.base import create_param
+from tpu_inference.layers.jax.layers import FlaxUtils
+
+modeling_flax_utils = FlaxUtils()
+
+
+@dataclass(kw_only=True)
+class Router(nnx.Module):
+    """Router module for Mixture-of-Experts (MoE) layers.
+
+    This module determines which experts each token should be routed to based on the input.
+
+    Attributes:
+    """
+    dtype: jnp.dtype
+    hidden_size: int
+    num_experts: int
+    num_experts_per_tok: int
+    router_act: str
+    rngs: InitVar[nnx.Rngs]
+    activation_ffw_td: Sharding
+    ed_sharding: Sharding
+    random_init: bool = False
+
+    def __call__(self, x_TD: Float):
+        """Routes tokens to experts.
+
+        Args:
+            x_TD: Input array of shape (sequence_length, d_model).
+
+        Returns:
+            A tuple containing:
+            - normalized_weights_TX: Normalized weights for selected experts, shape (sequence_length, num_experts_per_tok).
+            - selected_experts_TX: Indices of selected experts, shape (sequence_length, num_experts_per_tok).
+        """
+        x_TD = jnp.asarray(x_TD, self.dtype)
+        x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
+        router_act = modeling_flax_utils.ACT2FN[self.router_act]
+        router_logits_TE = jnp.einsum('TD,DE -> TE', x_TD,
+                                      self.kernel_DE.value)
+        weights_TX, selected_experts_TX = jax.lax.top_k(
+            router_logits_TE, self.num_experts_per_tok)
+        if self.router_act != "sigmoid":  # sigmoid does not accept axis argument.
+            normalized_weights_TX = router_act(weights_TX.astype(self.dtype),
+                                               axis=-1)
+        else:
+            normalized_weights_TX = router_act(weights_TX.astype(self.dtype))
+        return normalized_weights_TX, selected_experts_TX
+
+    def __post_init__(self, rngs: nnx.Rngs):
+        """Generates the router kernel (weights) for routing."""
+        shape = (self.hidden_size, self.num_experts)
+        self.kernel_DE = create_param(rngs,
+                                      shape=shape,
+                                      dtype=self.dtype,
+                                      sharding=self.ed_sharding,
+                                      random_init=self.random_init)
+
+
+@dataclass(kw_only=True)
+class MoE(nnx.Module):
+    """Mixture-of-Experts (MoE) Routed MLP Layer.
+
+    This module implements a MoE layer with a router and multiple expert MLPs.
+
+    Attributes:
+        router: The Router module.
+    """
+    dtype: jnp.dtype
+    num_local_experts: int
+    apply_expert_weight_before_computation: bool
+    hidden_size: int
+    intermediate_size_moe: int
+    hidden_act: str
+    rngs: InitVar[nnx.Rngs]
+    router: nnx.Module
+    activation_ffw_td: Sharding
+    activation_ffw_ted: Sharding
+    edf_sharding: Sharding
+    efd_sharding: Sharding
+    random_init: bool = False
+
+    def __call__(self, x_TD: Float):
+        """Performs the forward pass of the MoE layer.
+
+        Args:
+            x_TD: Input array of shape (sequence_length, d_model).
+
+        Returns:
+            Output array of shape (sequence_length, d_model) after passing through MoE.
+        """
+        x_TD = jnp.asarray(x_TD, self.dtype)
+        x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
+        weights_TX, indices_TX = self.router(x_TD)
+        one_hot_indices_TXE = jax.nn.one_hot(
+            indices_TX, num_classes=self.num_local_experts, dtype=self.dtype)
+        full_weights_TE = jnp.sum(one_hot_indices_TXE * weights_TX[..., None],
+                                  axis=1)
+
+        # Some models use the routing scores to weight the data instead of
+        # weighting the expert outputs.
+        if self.apply_expert_weight_before_computation:
+            with jax.named_scope("pre_computing_weight"):
+                return self._moe_fwd_preapply_router_weights(
+                    x_TD, full_weights_TE)
+        else:
+            return self._moe_fwd(x_TD, full_weights_TE)
+
+    def __post_init__(self, rngs: nnx.Rngs):
+        """Generates the kernels (weights) for the router and experts (gating, up-projection, and down-projection layers)."""
+
+        D = self.hidden_size
+        F = self.intermediate_size_moe
+        shape_gating = (self.num_local_experts, D, F)
+        shape_up = (self.num_local_experts, D, F)
+        shape_down = (self.num_local_experts, F, D)
+
+        self.kernel_gating_EDF = create_param(rngs,
+                                              shape=shape_gating,
+                                              dtype=self.dtype,
+                                              sharding=self.edf_sharding,
+                                              random_init=self.random_init)
+        self.kernel_up_proj_EDF = create_param(rngs,
+                                               shape=shape_up,
+                                               dtype=self.dtype,
+                                               sharding=self.edf_sharding,
+                                               random_init=self.random_init)
+        self.kernel_down_proj_EFD = create_param(rngs,
+                                                 shape=shape_down,
+                                                 dtype=self.dtype,
+                                                 sharding=self.efd_sharding,
+                                                 random_init=self.random_init)
+
+    def _moe_fwd_preapply_router_weights(self, x_TD: jax.Array, weights_TE):
+        """Performs the forward pass of the MoE experts with router weights pre-applied to the inputs.
+
+        Args:
+            x_TD: Input array for the experts, shape (sequence_length, hidden_size).
+            weights_TE: Router weights, shape (sequence_length, num_experts).
+
+        Returns:
+            Output array of shape (sequence_length, d_model).
+        """
+        # Data needs to be replicated since it will be weighted by the router
+        # scores before being passed to each expert.
+        num_experts = weights_TE.shape[-1]
+        x_TED = jnp.repeat(x_TD[:, None, :], num_experts, 1)
+        weights_TED = weights_TE[..., None]
+        x_TED = jnp.asarray(x_TED, self.dtype)
+
+        with jax.named_scope("activation_expert_weighting"):
+            x_TED = x_TED * weights_TED
+
+        x_TED = nnx.with_sharding_constraint(x_TED, self.activation_ffw_ted)
+        with jax.named_scope("gating"):
+            gating_TEF = jnp.einsum('TED,EDF -> TEF', x_TED,
+                                    self.kernel_gating_EDF.value)
+            activated_gating_TEF = modeling_flax_utils.ACT2FN[self.hidden_act](
+                gating_TEF)
+        with jax.named_scope("up_projection"):
+            up_proj_TEF = jnp.einsum('TED,EDF -> TEF', x_TED,
+                                     self.kernel_up_proj_EDF.value)
+
+        fuse_TEF = activated_gating_TEF * up_proj_TEF
+
+        with jax.named_scope("down_projection"):
+            down_proj_TED = jnp.einsum('TEF,EFD -> TED', fuse_TEF,
+                                       self.kernel_down_proj_EFD.value)
+        with jax.named_scope("sum"):
+            output_TD = down_proj_TED.sum(axis=1)
+        return output_TD.astype(self.dtype)
+
+    def _moe_fwd(self, x_TD: Float, weights):
+        """Performs the basic forward pass of the MoE experts without dropping or megablocks.
+
+        Args:
+            x_TD: Input array for the experts, shape (sequence_length, d_model).
+            weights: Weights for combining expert outputs, shape (sequence_length, num_experts).
+
+        Returns:
+            Output array of shape (sequence_length, d_model).
+        """
+        x_TD = jnp.asarray(x_TD, self.dtype)
+        x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
+        with jax.named_scope("gating"):
+            gating_TEF = jnp.einsum('TD,EDF -> TEF', x_TD,
+                                    self.kernel_gating_EDF.value)
+            activated_gating_TEF = modeling_flax_utils.ACT2FN[self.hidden_act](
+                gating_TEF)
+        with jax.named_scope("up_projection"):
+            up_proj_TEF = jnp.einsum('TD,EDF -> TEF', x_TD,
+                                     self.kernel_up_proj_EDF.value)
+
+        fuse_TEF = activated_gating_TEF * up_proj_TEF
+
+        with jax.named_scope("down_projection"):
+            down_proj_TED = jnp.einsum('TEF,EFD -> TED', fuse_TEF,
+                                       self.kernel_down_proj_EFD.value)
+        with jax.named_scope("sum"):
+            output_TD = jnp.einsum('TED,TE -> TD', down_proj_TED, weights)
+        return output_TD.astype(self.dtype)
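For orientation, the routing and dense-expert math added above can be exercised in isolation. The following is a minimal, self-contained sketch of the `_moe_fwd` path using plain jnp arrays in place of the module's `create_param` kernels; all sizes are hypothetical, and softmax/silu stand in for whatever activations the package's `FlaxUtils.ACT2FN` table resolves from the model config.

import jax
import jax.numpy as jnp

# Hypothetical sizes: T tokens, D hidden, E experts, F expert FFN width, X experts per token.
T, D, E, F, X = 4, 8, 4, 16, 2
k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 5)

x_TD = jax.random.normal(k1, (T, D))
w_router_DE = jax.random.normal(k2, (D, E))
w_gate_EDF = jax.random.normal(k3, (E, D, F))
w_up_EDF = jax.random.normal(k4, (E, D, F))
w_down_EFD = jax.random.normal(k5, (E, F, D))

# Router: logits -> top-k experts -> normalized scores (softmax assumed here).
logits_TE = jnp.einsum('TD,DE->TE', x_TD, w_router_DE)
weights_TX, experts_TX = jax.lax.top_k(logits_TE, X)
weights_TX = jax.nn.softmax(weights_TX, axis=-1)

# Scatter the k scores back to a dense (T, E) map, as MoE.__call__ does.
full_weights_TE = jnp.sum(
    jax.nn.one_hot(experts_TX, E) * weights_TX[..., None], axis=1)

# Dense expert forward pass mirroring _moe_fwd: every expert sees every token,
# and the router scores combine the expert outputs at the end.
gating_TEF = jax.nn.silu(jnp.einsum('TD,EDF->TEF', x_TD, w_gate_EDF))
up_TEF = jnp.einsum('TD,EDF->TEF', x_TD, w_up_EDF)
down_TED = jnp.einsum('TEF,EFD->TED', gating_TEF * up_TEF, w_down_EFD)
out_TD = jnp.einsum('TED,TE->TD', down_TED, full_weights_TE)
print(out_TD.shape)  # (4, 8)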
tpu_inference/layers/jax/rope.py
@@ -0,0 +1,172 @@
+import math
+from dataclasses import dataclass, field
+from typing import Optional
+
+import jax
+from flax import nnx
+from jax import numpy as jnp
+
+
+@dataclass(kw_only=True)
+class RotaryEmbedding(nnx.Module):
+    """
+    An implementation of the original rotary positional embedding.
+    """
+    rotary_dim: int
+    rope_theta: float
+    original_max_position_embeddings: int
+    dtype: jnp.dtype
+    sin_cos_cache: Optional[jax.Array] = field(init=False, default=None)
+
+    def initialize_cache(self):
+        """Computes and caches the sin/cos embeddings."""
+        if self.sin_cos_cache is None:
+            self.sin_cos_cache = self._compute_sin_cos()
+
+    def _compute_inv_freq(self):
+        fractions_H = jnp.arange(0, self.rotary_dim, 2,
+                                 dtype=jnp.float32) / self.rotary_dim
+        inv_freq_H = 1.0 / (self.rope_theta**fractions_H)
+        return inv_freq_H
+
+    def _compute_sin_cos(self):
+        inv_freq_H = self._compute_inv_freq()
+        t = jnp.arange(self.original_max_position_embeddings,
+                       dtype=jnp.float32)
+
+        freqs = jnp.einsum("...T,k->...Tk",
+                           t,
+                           inv_freq_H,
+                           precision=jax.lax.Precision.HIGHEST)
+        sin, cos = jnp.sin(freqs), jnp.cos(freqs)
+        cache = jnp.concatenate((cos, sin), axis=-1)
+        return cache
+
+    def apply_rope(self, positions: jax.Array, x_TNH: jax.Array):
+        assert x_TNH.ndim == 3
+        assert self.sin_cos_cache is not None, "RoPE cache not initialized."
+        cos_sin_TH = self.sin_cos_cache[positions]
+        # cos, sin: (T, H/2)
+        cos_TH, sin_TH = jnp.split(cos_sin_TH, 2, axis=-1)
+        assert sin_TH.ndim == 2 and cos_TH.ndim == 2
+        # cos, sin: (T, 1, H/2)
+        cos_T1H, sin_T1H = cos_TH[:, None, :], sin_TH[:, None, :]
+        # first_half, second_half: (T, N, H/2)
+        first_half_TNH, second_half_TNH = jnp.split(x_TNH, 2, axis=-1)
+        combined = jnp.concatenate([
+            first_half_TNH * cos_T1H - second_half_TNH * sin_T1H,
+            second_half_TNH * cos_T1H + first_half_TNH * sin_T1H
+        ],
+                                   axis=-1)
+        return combined.astype(self.dtype)
+
+
+@dataclass(kw_only=True)
+class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
+    """
+    Rotary Embedding for deepseek, with scaling and YaRN method.
+    """
+    scaling_factor: float
+    beta_fast: int = 32
+    beta_slow: int = 1
+    mscale_value: float = 1
+    mscale_all_dim: float = 0
+
+    def initialize_cache(self):
+        """Computes and caches the sin/cos embeddings."""
+        # The second condition is for the Qwix case, where we need to call `initialize_cache` on
+        # the abstract model. Thus, when we go to call `initialize_cache` on the concrete model,
+        # this method will have been called already, but we need to recompute the cache so that
+        # it's concrete (otherwise, it'll still be a jax.ShapeDtypeStruct).
+        if self.sin_cos_cache is not None and not isinstance(
+                self.sin_cos_cache, jax.ShapeDtypeStruct):
+            return
+        self.mscale = _yarn_get_mscale(
+            self.scaling_factor, self.mscale_value) / _yarn_get_mscale(
+                self.scaling_factor, self.mscale_all_dim)
+        self.sin_cos_cache = self._compute_sin_cos()
+
+    def _compute_inv_freq(self):
+        fractions = jnp.arange(0, self.rotary_dim, 2,
+                               dtype=jnp.float32) / self.rotary_dim
+        inv_freq_extrapolation = 1.0 / (self.rope_theta**fractions)
+        inv_freq_interpolation = 1.0 / (self.scaling_factor *
+                                        self.rope_theta**fractions)
+        low, high = _yarn_find_correction_range(
+            self.beta_fast, self.beta_slow, self.rotary_dim, self.rope_theta,
+            self.original_max_position_embeddings)
+
+        # Get n-d rotational scaling corrected for extrapolation
+        inv_freq_mask = 1 - _yarn_linear_ramp_mask(
+            low, high, self.rotary_dim // 2).astype(jnp.float32)
+        inv_freq = inv_freq_interpolation * (
+            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+        return inv_freq
+
+    def _compute_sin_cos(self):
+        inv_freq_H = self._compute_inv_freq()
+        t = jnp.arange(self.original_max_position_embeddings *
+                       self.scaling_factor,
+                       dtype=jnp.float32)
+        freqs = jnp.einsum("...T,k->...Tk", t, inv_freq_H)
+        sin, cos = jnp.sin(freqs) * self.mscale, jnp.cos(freqs) * self.mscale
+        cache = jnp.concatenate((cos, sin), axis=-1)
+        return cache
+
+    def apply_rope(self, positions: jax.Array, x_TNH: jax.Array):
+        assert x_TNH.ndim == 3
+        assert self.sin_cos_cache is not None, "RoPE cache not initialized."
+        cos_sin_TH = self.sin_cos_cache[positions]
+        # cos, sin: (T, H/2)
+        cos_TH, sin_TH = jnp.split(cos_sin_TH, 2, axis=-1)
+        assert sin_TH.ndim == 2 and cos_TH.ndim == 2
+        # cos, sin: (T, 1, H/2)
+        cos_T1H, sin_T1H = cos_TH[:, None, :], sin_TH[:, None, :]
+        # even, odd: (T, N, H/2)
+        even_TNH, odd_TNH = x_TNH[..., ::2], x_TNH[..., 1::2]
+        combined_TNH = jnp.stack([
+            even_TNH * cos_T1H - odd_TNH * sin_T1H,
+            odd_TNH * cos_T1H + even_TNH * sin_T1H
+        ],
+                                 axis=-1).reshape(x_TNH.shape)
+        return combined_TNH.astype(self.dtype)
+
+
+# Calculates the temperature scaling factor for YaRN to adjust
+# RoPE embedding magnitudes.
+def _yarn_get_mscale(scale, mscale):
+    return jnp.where(scale <= 1, 1.0, 0.1 * mscale * jnp.log(scale) + 1.0)
+
+
+# Inverses dim formula to find dim based on number of rotations.
+def _yarn_find_correction_dim(num_rotations,
+                              dim,
+                              base=10000,
+                              max_position_embeddings=2048):
+    return (dim * math.log(max_position_embeddings /
+                           (num_rotations * 2 * math.pi))) / (2 *
+                                                              math.log(base))
+
+
+# Finds dim range bounds based on rotations.
+def _yarn_find_correction_range(low_rot,
+                                high_rot,
+                                dim,
+                                base=10000,
+                                max_position_embeddings=2048):
+    low = math.floor(
+        _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
+    high = math.ceil(
+        _yarn_find_correction_dim(high_rot, dim, base,
+                                  max_position_embeddings))
+    return max(low, 0), min(high, dim - 1)  # Clamp values just in case
+
+
+# Creates a 1D mask that ramps linearly from 0 to 1 between min and max indices.
+def _yarn_linear_ramp_mask(min, max, dim):
+    if min == max:
+        max += 0.001  # Prevent singularity
+
+    linear_func = (jnp.arange(dim, dtype=jnp.float32) - min) / (max - min)
+    ramp_func = jnp.clip(linear_func, 0, 1)
+    return ramp_func
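A minimal usage sketch of the base class above, with hypothetical sizes; it assumes `RotaryEmbedding` is importable from the new `tpu_inference.layers.jax.rope` module and can be constructed directly with keyword arguments, as its dataclass definition suggests.

import jax.numpy as jnp

# Hypothetical configuration; shapes follow the T/N/H naming used above.
rope = RotaryEmbedding(rotary_dim=64,
                       rope_theta=10000.0,
                       original_max_position_embeddings=2048,
                       dtype=jnp.bfloat16)
rope.initialize_cache()  # builds the (2048, 64) concatenated cos/sin table

positions = jnp.arange(8)                          # (T,)
q_TNH = jnp.ones((8, 4, 64), dtype=jnp.bfloat16)   # (tokens, heads, head_dim)
q_rot_TNH = rope.apply_rope(positions, q_TNH)      # same shape and dtype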
tpu_inference/layers/jax/rope_interface.py
@@ -0,0 +1,214 @@
+import math
+from typing import Any, Dict
+
+import jax
+import jax.numpy as jnp
+
+
+def apply_rope(
+        # (seq_len, num_heads, head_dim)
+        inputs: jax.Array,
+        # (3, seq_len) for M-RoPE, otherwise (seq_len,)
+        positions: jax.Array,
+        head_dim: int,
+        rope_theta: float = 10000,
+        rope_scaling: Dict[str, Any] = None,
+        rope_input_ordering: str = "split",
+) -> jax.Array:
+    """
+    Applies Rotary Positional Embedding using the sine and cosine strategy.
+
+    This implementation assumes the input tensor has a shape that might include
+    padding on the last dimension (head_dim).
+    RoPE is applied only to the first `head_dim` features, and the result is
+    padded back to the original dimension if necessary.
+    If rope_input_ordering is "split", then the input pairs for rotation are taken one from the
+    first and one from the second half of the head_dim. If it is "interleaved" then
+    adjacent values are used as inputs for rotation.
+    """
+
+    # M-RoPE support for Qwen2.5-VL
+    if positions.ndim == 2 and positions.shape[0] == 3:
+        mrope_section = rope_scaling.get("mrope_section",
+                                         None) if rope_scaling else None
+        # NOTE: We assume mrope_section is always available
+        # as Qwen2.5-VL is the only model using mrope
+        assert mrope_section is not None
+
+        split_indices = [mrope_section[0], mrope_section[0] + mrope_section[1]]
+
+        # Indices for the features to be rotated (first half of head_dim)
+        all_freq_indices = jnp.arange(head_dim // 2)
+
+        # Split the indices according to mrope_section. This is valid because split_indices are static.
+        freq_indices_split = jnp.split(all_freq_indices, split_indices)
+        # freq_indices_split is a list of 3 JAX arrays.
+
+        cos_list = []
+        sin_list = []
+
+        for i in range(3):  # For each of the 3 position dimensions
+            current_indices = freq_indices_split[i]
+
+            if current_indices.size == 0:
+                # This section is empty, skip.
+                continue
+
+            # inv_freq shape: (mrope_section[i],)
+            inv_freq = 1.0 / (rope_theta**(current_indices * 2.0 / head_dim))
+
+            # positions[i]: (seq_len,)
+            # freqs shape: (seq_len, mrope_section[i])
+            freqs = jnp.outer(positions[i], inv_freq)
+
+            cos_list.append(jnp.cos(freqs))
+            sin_list.append(jnp.sin(freqs))
+
+        # Concatenate along the feature dimension
+        # cos, sin shape: (seq_len, head_dim//2)
+        cos = jnp.concatenate(cos_list, axis=1)
+        sin = jnp.concatenate(sin_list, axis=1)
+
+        # Add num_heads dimension for broadcasting
+        cos = cos[:, jnp.newaxis, :]  # Shape: (seq_len, 1, head_dim//2)
+        sin = sin[:, jnp.newaxis, :]  # Shape: (seq_len, 1, head_dim//2)
+
+        # Apply rotation
+        inputs_real = inputs[..., :head_dim // 2]
+        inputs_imag = inputs[..., head_dim // 2:head_dim]
+
+        outputs_real = inputs_real * cos - inputs_imag * sin
+        outputs_imag = inputs_real * sin + inputs_imag * cos
+
+        out = jnp.concatenate([outputs_real, outputs_imag], axis=-1)
+
+    # Standard RoPE
+    else:
+        # Calculate inverse frequencies (timescale)
+        fraction = 2 * jnp.arange(0, head_dim // 2) / head_dim
+        timescale = 1.0 / (rope_theta**fraction)
+
+        # Apply scaling if provided
+        if rope_scaling:
+            timescale = apply_rope_scaling(timescale, rope_scaling)
+
+        # Prepare for rotation by calculating sin and cos values
+        # `sinusoid_inp` gets shape (batch * seq_len, head_dim/2)
+        sinusoid_inp = positions[..., jnp.newaxis] * timescale[jnp.newaxis, :]
+
+        # Broadcast over the 'heads' dimension, assuming shape (batch*seq, heads, head_dim)
+        sinusoid_inp = sinusoid_inp[:, jnp.newaxis, ...]
+        sin = jnp.sin(sinusoid_inp)
+        cos = jnp.cos(sinusoid_inp)
+
+        if rope_input_ordering == "interleaved":
+            # Reshape to group adjacent features for rotation, matching new_apply_rope
+            rotary_inputs = inputs[
+                ..., :head_dim]  # Take just the non-padded amount.
+            reshaped_inputs = rotary_inputs.reshape(*rotary_inputs.shape[:-1],
+                                                    -1, 2)
+
+            # Apply the rotation
+            first_half = reshaped_inputs[..., 0]
+            second_half = reshaped_inputs[..., 1]
+        else:
+            first_half = inputs[..., :head_dim // 2]
+            second_half = inputs[..., head_dim // 2:head_dim]
+
+        first_part = first_half * cos - second_half * sin
+        second_part = second_half * cos + first_half * sin
+
+        # Combine the rotated parts and reshape back
+        if rope_input_ordering == "interleaved":
+            out_stacked = jnp.stack([first_part, second_part], axis=-1)
+            out = out_stacked.reshape(rotary_inputs.shape)
+        else:
+            out = jnp.concatenate([first_part, second_part], axis=-1)
+
+    # If the original input was padded, pad the output with zeros to match.
+    padded_head_dim = inputs.shape[-1]
+    if padded_head_dim > head_dim:
+        pad_width = padded_head_dim - head_dim
+        pad_config = [(0, 0)] * (out.ndim - 1) + [(0, pad_width)]
+        out = jnp.pad(out, pad_config)
+
+    return out.astype(inputs.dtype)
+
+
+def apply_longrope(
+    inputs: jax.Array,
+    positions: jax.Array,
+    head_dim: int,
+    rope_scaling: Dict[str, Any],
+    original_max_position_embeddings: int,
+    max_position_embeddings: int,
+    rope_theta: float = 10000,
+) -> jax.Array:
+    # LongRoPE implementation specific to Phi-3
+    # Implementation based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/phi3/modeling_phi3.py#L197-L235
+
+    scale = max_position_embeddings / original_max_position_embeddings
+    if scale <= 1.0:
+        mscale = 1.0
+    else:
+        mscale = jnp.sqrt(1 + (jnp.log(scale) /
+                               jnp.log(original_max_position_embeddings)))
+
+    seq_len = inputs.shape[0]
+    if seq_len > original_max_position_embeddings:
+        long_factor = jnp.array(rope_scaling.get("long_factor"))
+        timescale = 1.0 / (long_factor * (rope_theta**(
+            (2 * jnp.arange(0, head_dim // 2)) / head_dim)))
+    else:
+        short_factor = jnp.array(rope_scaling.get("short_factor"))
+        timescale = 1.0 / (short_factor * (rope_theta**(
+            (2 * jnp.arange(0, head_dim // 2)) / head_dim)))
+
+    # Calculate RoPE positions
+    sinusoid_inp = positions[..., jnp.newaxis] * timescale[jnp.newaxis, :]
+    sinusoid_inp = sinusoid_inp[:, jnp.newaxis, ...]
+    sin = jnp.sin(sinusoid_inp) * mscale
+    cos = jnp.cos(sinusoid_inp) * mscale
+
+    # Padding logic
+    padded_head_dim = inputs.shape[-1]
+
+    # Apply RoPE mechanism
+    first_half = inputs[..., :head_dim // 2]
+    second_half = inputs[..., head_dim // 2:head_dim]
+
+    first_part = first_half * cos - second_half * sin
+    second_part = second_half * cos + first_half * sin
+    out = jnp.concatenate([first_part, second_part], axis=-1)
+
+    if padded_head_dim > head_dim:
+        out = jnp.pad(out, ((0, 0), (0, 0), (0, padded_head_dim - head_dim)))
+
+    return out.astype(inputs.dtype)
+
+
+def apply_rope_scaling(freqs: jax.Array, rope_scaling: Dict[str,
+                                                            Any]) -> jax.Array:
+    # Values obtained from grid search
+    scale_factor = rope_scaling.get("scale_factor", 8.0)
+    low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
+    high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
+    old_context_len = rope_scaling.get("original_max_position_embeddings",
+                                       8192)
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+
+    wavelen = 2 * math.pi / freqs
+    smooth = (old_context_len / wavelen -
+              low_freq_factor) / (high_freq_factor - low_freq_factor)
+
+    high_freqs = jnp.where(wavelen < high_freq_wavelen, freqs, 0)
+    low_freqs = jnp.where(wavelen > low_freq_wavelen, freqs / scale_factor, 0)
+    mid_freqs = jnp.where(
+        (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen),
+        (1 - smooth) * freqs / scale_factor + smooth * freqs,
+        0,
+    )
+    new_freqs = high_freqs + low_freqs + mid_freqs
+    return new_freqs
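A minimal sketch of the functional interface above, with hypothetical shapes: a real head_dim of 48 padded to 64, ordinary one-dimensional positions (so the non-M-RoPE branch runs), and an illustrative rope_scaling dict whose keys are the ones read by apply_rope_scaling. It assumes apply_rope is imported from the new rope_interface module.

import jax.numpy as jnp

head_dim = 48                                    # real rotary width
positions = jnp.arange(6)                        # (seq_len,)
q = jnp.ones((6, 2, 64), dtype=jnp.bfloat16)     # (seq_len, heads, padded head_dim)

# Standard RoPE with the default "split" ordering; the trailing 16 padded
# features come back zero-padded and the input dtype is preserved.
q_rot = apply_rope(q, positions, head_dim, rope_theta=10000)
print(q_rot.shape)  # (6, 2, 64)

# Frequency-band scaling is applied when a rope_scaling dict is supplied.
q_scaled = apply_rope(q,
                      positions,
                      head_dim,
                      rope_scaling={
                          "scale_factor": 8.0,
                          "low_freq_factor": 1.0,
                          "high_freq_factor": 4.0,
                          "original_max_position_embeddings": 8192,
                      })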
File without changes